diff --git a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu
index 3e2f87db..8ba617a9 100644
--- a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu
+++ b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu
@@ -14,7 +14,7 @@ __global__ void awq_marlin_repack_kernel(
   int n_tiles = size_n / tile_n_size;
   int block_k_tiles = div_ceil(k_tiles, gridDim.x);
 
-  int start_k_tile = blockIdx.x * block_k_tiles;
+  auto start_k_tile = blockIdx.x * block_k_tiles;
   if (start_k_tile >= k_tiles) {
     return;
   }
@@ -51,8 +51,8 @@ __global__ void awq_marlin_repack_kernel(
     int4* sh_ptr = sh + stage_size * pipe;
 
     if (threadIdx.x < stage_size) {
-      int k_id = threadIdx.x / stage_n_threads;
-      int n_id = threadIdx.x % stage_n_threads;
+      auto k_id = threadIdx.x / stage_n_threads;
+      auto n_id = threadIdx.x % stage_n_threads;
 
       int first_k = k_tile_id * tile_k_size;
 
@@ -70,8 +70,8 @@ __global__ void awq_marlin_repack_kernel(
       return;
     }
 
-    int warp_id = threadIdx.x / 32;
-    int th_id = threadIdx.x % 32;
+    auto warp_id = threadIdx.x / 32;
+    auto th_id = threadIdx.x % 32;
 
     if (warp_id >= 4) {
       return;
@@ -265,4 +265,4 @@ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
 
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
   m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
-}
\ No newline at end of file
+}
diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu
index dafab501..14d397d0 100644
--- a/csrc/quantization/gptq_marlin/gptq_marlin.cu
+++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -460,7 +460,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                     int const* __restrict__ perm_int_ptr,
                                     int4* __restrict__ out_int4_ptr, int size_m,
                                     int size_k, int lda, int block_rows) {
-  int start_row = block_rows * blockIdx.x;
+  auto start_row = block_rows * blockIdx.x;
   int finish_row = start_row + block_rows;
   if (finish_row > size_m) {
     finish_row = size_m;
   }
@@ -484,7 +484,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
     int base_k = 0;
 
     for (int i = 0; i < iters; i++) {
-      int cur_k = base_k + threadIdx.x;
+      auto cur_k = base_k + threadIdx.x;
       int src_pos = perm_int_ptr[cur_k];
 
       out_half[cur_k] = a_row_half[src_pos];
@@ -494,7 +494,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
 
     if (rest) {
      if (threadIdx.x < rest) {
-        int cur_k = base_k + threadIdx.x;
+        auto cur_k = base_k + threadIdx.x;
        int src_pos = perm_int_ptr[cur_k];
 
        out_half[cur_k] = a_row_half[src_pos];
@@ -723,8 +723,8 @@ __global__ void Marlin(
                 (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
   b_gl_rd += b_sh_stride * slice_col;
   b_gl_rd += b_gl_rd_delta_o * slice_row;
-  int b_sh_wr = threadIdx.x * b_thread_vecs;
-  int b_sh_rd = threadIdx.x * b_thread_vecs;
+  auto b_sh_wr = threadIdx.x * b_thread_vecs;
+  auto b_sh_rd = threadIdx.x * b_thread_vecs;
 
   // For act_order
   constexpr int k_iter_size = tb_k / b_sh_wr_iters;
@@ -743,7 +743,7 @@ __global__ void Marlin(
                 s_sh_stride * slice_col + threadIdx.x;
     }
   }
-  int s_sh_wr = threadIdx.x;
+  auto s_sh_wr = threadIdx.x;
   bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
 
   // Zero-points
@@ -756,7 +756,7 @@ __global__ void Marlin(
                 zp_sh_stride * slice_col + threadIdx.x;
     }
   }
-  int zp_sh_wr = threadIdx.x;
+  auto zp_sh_wr = threadIdx.x;
   bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
 
   // We use a different scale layout for grouped and column-wise quantization as
@@ -1047,7 +1047,7 @@ __global__ void Marlin(
         int4* sh_s_stage = sh_s + s_sh_stage * pipe;
         reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
       } else {
-        int warp_id = threadIdx.x / 32;
+        auto warp_id = threadIdx.x / 32;
         int n_warps = thread_n_blocks / 4;
 
         int warp_row = warp_id / n_warps;
@@ -1085,7 +1085,7 @@ __global__ void Marlin(
 
       // Determine "position" inside the thread-block (based on warp and
       // thread-id)
-      int warp_id = threadIdx.x / 32;
+      auto warp_id = threadIdx.x / 32;
       int n_warps =
           thread_n_blocks / 4;  // Each warp processes 4 16-size tiles over N
 
@@ -1094,7 +1094,7 @@ __global__ void Marlin(
 
       cur_k += warp_row * 16;
 
-      int th_id = threadIdx.x % 32;
+      auto th_id = threadIdx.x % 32;
       cur_k += (th_id % 4) * 2;  // Due to tensor-core layout for fp16 B matrix
 
      int s_col_shift =
@@ -1159,7 +1159,7 @@ __global__ void Marlin(
               (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
         }
       } else {
-        int warp_id = threadIdx.x / 32;
+        auto warp_id = threadIdx.x / 32;
         int n_warps = thread_n_blocks / 4;
 
         int warp_row = warp_id / n_warps;
@@ -1197,7 +1197,7 @@ __global__ void Marlin(
                                    (pipe / (group_blocks / thread_k_blocks)));
         reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
       } else {
-        int warp_id = threadIdx.x / 32;
+        auto warp_id = threadIdx.x / 32;
         int n_warps = thread_n_blocks / 4;
 
         int warp_row = warp_id / n_warps;
@@ -1323,7 +1323,7 @@ __global__ void Marlin(
   auto thread_block_reduce = [&]() {
     constexpr int red_off = threads / b_sh_stride_threads / 2;
     if (red_off >= 1) {
-      int red_idx = threadIdx.x / b_sh_stride_threads;
+      auto red_idx = threadIdx.x / b_sh_stride_threads;
       constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
       constexpr int red_sh_delta = b_sh_stride_threads;
       int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
@@ -1390,7 +1390,7 @@ __global__ void Marlin(
                   4 * (threadIdx.x / 32) + threadIdx.x % 4;
     c_gl_wr += (2 * thread_n_blocks) * slice_col;
     constexpr int c_sh_wr_delta = active_threads;
-    int c_sh_wr = threadIdx.x;
+    auto c_sh_wr = threadIdx.x;
 
     int row = (threadIdx.x % 32) / 4;
 
diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
index 5cd07855..7c2d089a 100644
--- a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
+++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
@@ -15,7 +15,7 @@ __global__ void gptq_marlin_repack_kernel(
   int n_tiles = size_n / tile_n_size;
   int block_k_tiles = div_ceil(k_tiles, gridDim.x);
 
-  int start_k_tile = blockIdx.x * block_k_tiles;
+  auto start_k_tile = blockIdx.x * block_k_tiles;
   if (start_k_tile >= k_tiles) {
     return;
   }
@@ -71,8 +71,8 @@ __global__ void gptq_marlin_repack_kernel(
 
     if constexpr (has_perm) {
       if (threadIdx.x < stage_size) {
-        int k_id = threadIdx.x / stage_n_threads;
-        int n_id = threadIdx.x % stage_n_threads;
+        auto k_id = threadIdx.x / stage_n_threads;
+        auto n_id = threadIdx.x % stage_n_threads;
 
         uint32_t const* sh_perm_int_ptr =
             reinterpret_cast<uint32_t const*>(sh_perm_ptr);
@@ -88,8 +88,8 @@ __global__ void gptq_marlin_repack_kernel(
 
    } else {
      if (threadIdx.x < stage_size) {
-        int k_id = threadIdx.x / stage_n_threads;
-        int n_id = threadIdx.x % stage_n_threads;
+        auto k_id = threadIdx.x / stage_n_threads;
+        auto n_id = threadIdx.x % stage_n_threads;
 
        int first_k = k_tile_id * tile_k_size;
        int first_k_packed = first_k / pack_factor;
@@ -109,8 +109,8 @@ __global__ void gptq_marlin_repack_kernel(
       return;
     }
 
-    int warp_id = threadIdx.x / 32;
-    int th_id = threadIdx.x % 32;
+    auto warp_id = threadIdx.x / 32;
+    auto th_id = threadIdx.x % 32;
 
     if (warp_id >= 4) {
       return;
@@ -339,4 +339,4 @@ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
 
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
   m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
-}
\ No newline at end of file
+}
diff --git a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
index 4db8f5dc..ba0a2410 100644
--- a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
+++ b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
@@ -277,12 +277,12 @@ __global__ void Marlin(
       b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
   b_gl_rd += b_sh_stride * slice_col;
   b_gl_rd += b_gl_rd_delta_o * slice_row;
-  int b_sh_wr = threadIdx.x;
-  int b_sh_rd = threadIdx.x;
+  auto b_sh_wr = threadIdx.x;
+  auto b_sh_rd = threadIdx.x;
 
   int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
                 s_sh_stride * slice_col + threadIdx.x;
-  int s_sh_wr = threadIdx.x;
+  auto s_sh_wr = threadIdx.x;
   int s_sh_rd;
   // We use a different scale layout for grouped and column-wise quantization as
   // we scale a `half2` tile in column-major layout in the former and in
@@ -455,7 +455,7 @@ __global__ void Marlin(
   auto thread_block_reduce = [&]() {
     constexpr int red_off = threads / b_sh_stride / 2;
     if (red_off >= 1) {
-      int red_idx = threadIdx.x / b_sh_stride;
+      auto red_idx = threadIdx.x / b_sh_stride;
       constexpr int red_sh_stride = b_sh_stride * 4 * 2;
       constexpr int red_sh_delta = b_sh_stride;
       int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
@@ -522,7 +522,7 @@ __global__ void Marlin(
                   4 * (threadIdx.x / 32) + threadIdx.x % 4;
     c_gl_wr += (2 * thread_n_blocks) * slice_col;
     constexpr int c_sh_wr_delta = active_threads;
-    int c_sh_wr = threadIdx.x;
+    auto c_sh_wr = threadIdx.x;
 
     int row = (threadIdx.x % 32) / 4;
 
diff --git a/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu b/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
index 048a3f73..cd183076 100644
--- a/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
+++ b/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
@@ -353,10 +353,10 @@ __global__ void Marlin(
       b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
   b_gl_rd += b_sh_stride * slice_col;
   b_gl_rd += b_gl_rd_delta_o * slice_row;
-  int b_sh_wr = threadIdx.x;
-  int b_sh_rd = threadIdx.x;
+  auto b_sh_wr = threadIdx.x;
+  auto b_sh_rd = threadIdx.x;
 
-  int s_tok_gl_rd = threadIdx.x;
+  auto s_tok_gl_rd = threadIdx.x;
   // NOTE(HandH1998): activation scale s_tok need shuffle to [0, 8, 1, 9, 2, 10,
   // 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] for example, 0, 8 row scales serve for
   // thread 0, 1, 2, 3. For more details, refer to mma operand A layout as
@@ -368,8 +368,8 @@ __global__ void Marlin(
   int s_tok_sh_rd = (threadIdx.x % 32) / 4;
   bool s_tok_sh_wr_pred = threadIdx.x < prob_m;
 
-  int s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
-  int s_ch_sh_wr = threadIdx.x;
+  auto s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
+  auto s_ch_sh_wr = threadIdx.x;
   int s_ch_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
                    2 * ((threadIdx.x % 32) % 4);
   bool s_ch_sh_wr_pred = threadIdx.x < s_ch_sh_stride;
@@ -558,7 +558,7 @@ __global__ void Marlin(
   auto thread_block_reduce = [&]() {
     constexpr int red_off = threads / b_sh_stride / 2;
     if (red_off >= 1) {
-      int red_idx = threadIdx.x / b_sh_stride;
+      auto red_idx = threadIdx.x / b_sh_stride;
       constexpr int red_sh_stride = b_sh_stride * 4 * 2;
       constexpr int red_sh_delta = b_sh_stride;
       int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
@@ -628,7 +628,7 @@ __global__ void Marlin(
                   8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2;
     c_gl_wr += (4 * thread_n_blocks) * slice_col;
     constexpr int c_sh_wr_delta = active_threads * 2;
-    int c_sh_wr = 2 * threadIdx.x;
+    auto c_sh_wr = 2 * threadIdx.x;
 
     int row = (threadIdx.x % 32) / 4;
 
diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
index 17837351..c33e71ae 100644
--- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
+++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
@@ -273,15 +273,15 @@ __global__ void Marlin_24(
                 (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
   b_gl_rd += b_sh_stride * slice_col;
   b_gl_rd += b_gl_rd_delta_o * slice_row;
-  int b_sh_wr = threadIdx.x * b_thread_vecs;
-  int b_sh_rd = threadIdx.x * b_thread_vecs;
+  auto b_sh_wr = threadIdx.x * b_thread_vecs;
+  auto b_sh_rd = threadIdx.x * b_thread_vecs;
 
   int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) +
                 (threadIdx.x % (m_sh_stride));
   m_gl_rd += (m_sh_stride)*slice_col;
   m_gl_rd += m_gl_rd_delta_o * slice_row;
-  int m_sh_wr = threadIdx.x;
-  int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16;
+  auto m_sh_wr = threadIdx.x;
+  auto m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16;
 
   int s_gl_rd;
   if constexpr (group_blocks == -1) {
@@ -291,7 +291,7 @@ __global__ void Marlin_24(
              s_sh_stride * slice_col + threadIdx.x;
   }
 
-  int s_sh_wr = threadIdx.x;
+  auto s_sh_wr = threadIdx.x;
   int s_sh_rd;
   // We use a different scale layout for grouped and column-wise quantization as
   // we scale a `half2` tile in column-major layout in the former and in
@@ -516,7 +516,7 @@ __global__ void Marlin_24(
   auto thread_block_reduce = [&]() {
     constexpr int red_off = threads / b_sh_stride_threads / 2;
     if (red_off >= 1) {
-      int red_idx = threadIdx.x / b_sh_stride_threads;
+      auto red_idx = threadIdx.x / b_sh_stride_threads;
       constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
       constexpr int red_sh_delta = b_sh_stride_threads;
       int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
@@ -583,7 +583,7 @@ __global__ void Marlin_24(
                   8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4;
     c_gl_wr += (2 * thread_n_blocks) * slice_col;
     constexpr int c_sh_wr_delta = active_threads;
-    int c_sh_wr = threadIdx.x;
+    auto c_sh_wr = threadIdx.x;
 
     int col = 2 * ((threadIdx.x % 32) % 4);
 
diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu
index c500d00e..8ab2af22 100644
--- a/csrc/rocm/attention.cu
+++ b/csrc/rocm/attention.cu
@@ -284,18 +284,18 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
     int max_ctx_blocks, const float* k_scale, const float* v_scale) {
   // clang-format on
   constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
-  const int warpid = threadIdx.x / WARP_SIZE;
-  const int laneid = threadIdx.x % WARP_SIZE;
+  const auto warpid = threadIdx.x / WARP_SIZE;
+  const auto laneid = threadIdx.x % WARP_SIZE;
   const int lane4id = laneid % 4;
   const int lane16id = laneid % 16;
   const int rowid = laneid / 16;
 
-  const int seq_idx = blockIdx.x;
-  const int partition_idx = blockIdx.y;
+  const auto seq_idx = blockIdx.x;
+  const auto partition_idx = blockIdx.y;
 
   constexpr int T_PAR_SIZE = 256;  // token partition size set to 256
 
-  const int max_num_partitions = gridDim.y;
+  const auto max_num_partitions = gridDim.y;
 
   const int context_len = context_lens[seq_idx];
 
@@ -346,9 +346,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
   // can be interpreted as B8x16 for 8 bit types
   _B16x8 Klocal[TLOOP][QKHELOOP];
 
-  const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
-  const int wg_start_kv_head_idx = blockIdx.z;
-  const int total_num_heads = gridDim.z * GQA_RATIO;
+  const auto wg_start_head_idx = blockIdx.z * GQA_RATIO;
+  const auto wg_start_kv_head_idx = blockIdx.z;
+  const auto total_num_heads = gridDim.z * GQA_RATIO;
 
   // for QK mfma, tokens in multiples of TOKENS_PER_WARP are spread across warps
   // each mfma takes QH16xT16x16HE across warp
@@ -789,14 +789,14 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
     int max_ctx_blocks, const float* k_scale, const float* v_scale) {
   // clang-format on
   constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
-  const int warpid = threadIdx.x / WARP_SIZE;
-  const int laneid = threadIdx.x % WARP_SIZE;
+  const auto warpid = threadIdx.x / WARP_SIZE;
+  const auto laneid = threadIdx.x % WARP_SIZE;
   const int lane4id = laneid % 4;
 
-  const int seq_idx = blockIdx.x;
-  const int partition_idx = blockIdx.y;
-  const int partition_size = blockDim.x;
-  const int max_num_partitions = gridDim.y;
+  const auto seq_idx = blockIdx.x;
+  const auto partition_idx = blockIdx.y;
+  const auto partition_size = blockDim.x;
+  const auto max_num_partitions = gridDim.y;
 
   const int context_len = context_lens[seq_idx];
   const int partition_start_token_idx = partition_idx * partition_size;
@@ -838,8 +838,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
     qk_max[h] = -FLT_MAX;
   }
 
-  const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
-  const int wg_start_kv_head_idx = blockIdx.z;
+  const auto wg_start_head_idx = blockIdx.z * GQA_RATIO;
+  const auto wg_start_kv_head_idx = blockIdx.z;
 
   const int warp_start_token_idx =
       partition_start_token_idx + warpid * WARP_SIZE;
@@ -857,7 +857,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
 
   // token id within partition
-  const int local_token_idx = threadIdx.x;
+  const auto local_token_idx = threadIdx.x;
   // token id within sequence
   const int global_token_idx = partition_start_token_idx + local_token_idx;
 
@@ -1126,7 +1126,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
 
   __syncthreads();
 
-  const int num_heads = gridDim.z * GQA_RATIO;
+  const auto num_heads = gridDim.z * GQA_RATIO;
   float* max_logits_ptr =
       max_logits + seq_idx * num_heads * max_num_partitions + partition_idx;
   float* exp_sums_ptr =
@@ -1268,14 +1268,14 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
                                            // max_num_partitions, head_size]
     const int* __restrict__ context_lens,  // [num_seqs]
     const int max_num_partitions) {
-  const int num_heads = gridDim.x;
-  const int head_idx = blockIdx.x;
-  const int seq_idx = blockIdx.y;
+  const auto num_heads = gridDim.x;
+  const auto head_idx = blockIdx.x;
+  const auto seq_idx = blockIdx.y;
   const int context_len = context_lens[seq_idx];
   const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
   [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  const int warpid = threadIdx.x / WARP_SIZE;
-  [[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE;
+  const auto warpid = threadIdx.x / WARP_SIZE;
+  [[maybe_unused]] const auto laneid = threadIdx.x % WARP_SIZE;
 
   __shared__ float shared_global_exp_sum;
   // max num partitions supported is warp_size * NPAR_LOOPS
@@ -1294,7 +1294,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 
#pragma unroll
   for (int i = 0; i < NPAR_LOOPS; i++) {
-    const int partition_no = i * WARP_SIZE + threadIdx.x;
+    const auto partition_no = i * WARP_SIZE + threadIdx.x;
     valid_partition[i] =
         (partition_no < num_partitions) ? partition_no : last_valid_partition;
   }
@@ -1324,7 +1324,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
   }
#pragma unroll
   for (int i = 0; i < NPAR_LOOPS; i++) {
-    const int partition_no = i * WARP_SIZE + threadIdx.x;
+    const auto partition_no = i * WARP_SIZE + threadIdx.x;
     rescaled_exp_sum[i] *= (partition_no < num_partitions)
                                ? expf(reg_max_logit[i] - max_logit)
                                : 0.0f;
@@ -1336,7 +1336,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
   }
#pragma unroll
   for (int i = 0; i < NPAR_LOOPS; i++) {
-    const int partition_no = i * WARP_SIZE + threadIdx.x;
+    const auto partition_no = i * WARP_SIZE + threadIdx.x;
     shared_exp_sums[partition_no] = rescaled_exp_sum[i];
   }
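
Note (not part of the patch): every hunk above replaces `int` with `auto` for values derived from the CUDA/HIP built-ins `threadIdx`, `blockIdx`, `blockDim`, and `gridDim`, whose components are `unsigned int`. The standalone sketch below only illustrates what `auto` deduces in that position; the kernel name and the `WARP_SIZE` constant are placeholders, not code from these files.

#include <type_traits>

constexpr int WARP_SIZE = 64;  // placeholder: 32 on NVIDIA warps, 64 on AMD wavefronts

__global__ void index_type_demo() {
  // threadIdx.x has type unsigned int; mixing it with an int operand applies
  // the usual arithmetic conversions, so the result is unsigned int. `auto`
  // therefore deduces unsigned int, whereas declaring `int` would force an
  // implicit unsigned-to-signed conversion.
  const auto warpid = threadIdx.x / WARP_SIZE;
  const auto laneid = threadIdx.x % WARP_SIZE;
  static_assert(std::is_same<decltype(warpid), const unsigned int>::value,
                "auto picks up the unsigned type of the built-in index");
  static_assert(std::is_same<decltype(laneid), const unsigned int>::value,
                "same deduction for the remainder");
  (void)warpid;  // silence unused-variable warnings in this illustration
  (void)laneid;
}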