Resolve race conditions in Marlin kernel (#11493)

Signed-off-by: wchen61 <wchen61@foxmail.com>
Author: wchen61 <wchen61@foxmail.com>
Date:   2025-01-03 06:58:56 +08:00 (committed by GitHub)
Parent: 187e32997c
Commit: 5dba257506

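This change addresses two races in the Marlin GEMM kernel. First, the cross-thread reduction staged partial sums through `sh`, which aliases the tile buffers that the cp.async pipeline keeps refilling, so reduction traffic could collide with in-flight prefetches; a dedicated `sh_red` buffer now holds all reduction and write-back staging. Second, dequantization scales were read from the pipeline slot that started the current quantization group rather than from the current slot, and that group-start slot can be overwritten by a prefetch for a later iteration while it is still being read; scales are now replicated into every stage so each iteration reads only the slot it owns. The shared-memory size check is extended to budget for the new reduction buffer.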

@@ -834,6 +834,7 @@ __global__ void Marlin(
   int4* sh_g_idx = sh_b + (stages * b_sh_stage);
   int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
   int4* sh_s = sh_zp + (stages * zp_sh_stage);
+  int4* sh_red = sh_s + (stages * s_sh_stage);

   // Register storage for double buffer of shared memory reads.
   FragA frag_a[2][thread_m_blocks];
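For reference, all of these buffers are carved out of the kernel's single dynamic shared-memory allocation by bumping one base pointer, and `sh_red` simply extends the chain past the scale buffer. A minimal sketch of the pattern, reusing the kernel's names but with placeholder stage sizes (the real `*_stage` values are computed elsewhere in the file):

    // Sketch: carving typed sub-buffers out of one dynamic shared allocation.
    // Stage counts and sizes are illustrative placeholders, not Marlin's.
    extern __shared__ int4 sh[];

    __device__ void shared_layout(int stages, int b_sh_stage, int g_idx_stage,
                                  int zp_sh_stage, int s_sh_stage) {
      int4* sh_b = sh;                                  // B-tile pipeline
      int4* sh_g_idx = sh_b + (stages * b_sh_stage);    // g_idx pipeline
      int4* sh_zp = sh_g_idx + (stages * g_idx_stage);  // zero-point pipeline
      int4* sh_s = sh_zp + (stages * zp_sh_stage);      // scale pipeline
      int4* sh_red = sh_s + (stages * s_sh_stage);      // NEW: reduction scratch
      (void)sh_red;  // occupies its own range; never aliases the tiles above
    }

Because `sh_red` occupies its own range, the reduction later in the kernel no longer has to be ordered against the async copies that refill the tile buffers.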
@@ -932,11 +933,11 @@ __global__ void Marlin(
     int4* sh_s_stage = sh_s + s_sh_stage * pipe;

     if constexpr (group_blocks >= thread_k_blocks) {
-      // Only fetch scales if this tile starts a new group
-      if (pipe % (group_blocks / thread_k_blocks) == 0) {
       if (s_sh_wr_pred) {
         cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
       }
+      // Only fetch scales if this tile starts a new group
+      if ((pipe + 1) % (group_blocks / thread_k_blocks) == 0) {
       s_gl_rd += s_gl_rd_delta;
       }
     } else {
@@ -1038,9 +1039,7 @@ __global__ void Marlin(
       // No act-order case
       if constexpr (group_blocks != -1) {
         if constexpr (group_blocks >= thread_k_blocks) {
-          int4* sh_s_stage =
-              sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
-                                   (pipe / (group_blocks / thread_k_blocks)));
+          int4* sh_s_stage = sh_s + s_sh_stage * pipe;
           reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
         } else {
           int warp_id = threadIdx.x / 32;
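Taken together with the fetch-side hunk above, this is the scale-staging half of the fix. Before: scales were copied into shared memory only on the pipeline slot that starts a new quantization group, and every later iteration of that group reached back into the group-start slot via the rounded-down index. Because the asynchronous prefetch runs a few slots ahead of the consumer, it can begin refilling exactly that slot while it is still being read. After: a copy of the current group's scales lands in every slot (the global read pointer now advances on the last slot of a group), and each iteration reads only the slot it owns. A host-side sketch of the two read-index schemes, with illustrative values stages = 4 and ratio = group_blocks / thread_k_blocks = 2:

    #include <cstdio>

    int main() {
      const int stages = 4;  // pipeline depth (illustrative)
      const int ratio = 2;   // group_blocks / thread_k_blocks (illustrative)
      for (int iter = 0; iter < 2 * stages; iter++) {
        int slot = iter % stages;               // slot this iteration computes on
        int old_read = ratio * (slot / ratio);  // old: reach back to group start
        int new_read = slot;                    // new: read the slot we own
        printf("iter=%d slot=%d old_read=%d new_read=%d\n", iter, slot,
               old_read, new_read);
      }
    }

With the old scheme, compute on slot 1 reads slot 0 at the same time the prefetcher may already be reloading slot 0 for a later iteration; with the new scheme, every read stays inside the slot whose async copy the pipeline has already waited on.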
@@ -1339,15 +1338,15 @@ __global__ void Marlin(
             int red_sh_wr =
                 red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
             if (i < red_off) {
-              float* c_rd =
-                  reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
-              float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
+              float* c_rd = reinterpret_cast<float*>(
+                  &sh_red[red_sh_delta * j + red_sh_rd]);
+              float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
 #pragma unroll
               for (int k = 0; k < 4; k++)
                 reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
                     c_rd[k] + c_wr[k];
             }
-            sh[red_sh_wr] =
+            sh_red[red_sh_wr] =
                 reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
           }
         }
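This hunk and the ones that follow are the mechanical half of the fix: every reduction access that used to index `sh`, which aliases the tile buffers, now goes through the dedicated `sh_red` region. The protected pattern, staging partial sums in a reduction-only shared buffer, looks roughly like the standalone sketch below (a generic block-sum kernel, not the kernel's actual fragment reduction; `out` is assumed to be zero-initialized, and blockDim.x a multiple of 32):

    // Standalone sketch: staging per-warp partial sums in a dedicated shared
    // buffer instead of reusing one that an async pipeline may be refilling.
    __global__ void block_sum(const float* in, float* out, int n) {
      __shared__ float sh_red[32];  // one slot per warp; never aliases tiles
      float v = 0.0f;
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
           i += gridDim.x * blockDim.x)
        v += in[i];
      for (int off = 16; off > 0; off >>= 1)  // reduce within each warp
        v += __shfl_down_sync(0xffffffff, v, off);
      if ((threadIdx.x & 31) == 0) sh_red[threadIdx.x >> 5] = v;
      __syncthreads();
      if (threadIdx.x < 32) {  // first warp reduces the per-warp partials
        v = (threadIdx.x < (blockDim.x + 31) / 32) ? sh_red[threadIdx.x] : 0.0f;
        for (int off = 16; off > 0; off >>= 1)
          v += __shfl_down_sync(0xffffffff, v, off);
        if (threadIdx.x == 0) atomicAdd(out, v);
      }
    }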
@@ -1357,7 +1356,7 @@ __global__ void Marlin(
 #pragma unroll
           for (int i = 0; i < 4 * 2; i++) {
             float* c_rd =
-                reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
+                reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
 #pragma unroll
             for (int j = 0; j < 4; j++)
               reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
@@ -1397,7 +1396,7 @@ __global__ void Marlin(
 #pragma unroll
       for (int i = 0; i < thread_m_blocks * 4; i++) {
         cp_async4_pred(
-            &sh[c_sh_wr + c_sh_wr_delta * i],
+            &sh_red[c_sh_wr + c_sh_wr_delta * i],
             &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
               c_gl_wr_delta_i * (i % 2)],
             i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
@@ -1410,7 +1409,7 @@ __global__ void Marlin(
       for (int i = 0; i < thread_m_blocks * 4; i++) {
         if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
           if (!first) {
-            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
+            int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
 #pragma unroll
             for (int j = 0; j < 2 * 4; j++) {
               reinterpret_cast<float*>(
@@ -1461,10 +1460,10 @@ __global__ void Marlin(
         float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
 #pragma unroll
         for (int k = 0; k < th_size; k++) {
-          sh[threadIdx.x] =
+          sh_red[threadIdx.x] =
               C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
-          float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
+          float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]);
 #pragma unroll
           for (int f = 0; f < 4; f++) {
             frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
           }
@@ -1515,7 +1514,7 @@ __global__ void Marlin(
           res = __hmul2(res, s[0]);
         }

-        ((scalar_t2*)sh)[idx] = res;
+        ((scalar_t2*)sh_red)[idx] = res;
       };

       if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1543,7 +1542,7 @@ __global__ void Marlin(
            i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
            i++) {
         if (c_gl_wr < c_gl_wr_end) {
-          C[c_gl_wr] = sh[c_sh_rd];
+          C[c_gl_wr] = sh_red[c_sh_rd];
           c_gl_wr += c_gl_wr_delta;
           c_sh_rd += c_sh_rd_delta;
         }
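The preceding hunks apply the same rename to the write-back path: previous partial results fetched from C, partials accumulated from C_tmp, and the final dequantized halves are all staged through sh_red on their way between registers and global memory, so none of that traffic can collide with the pipeline's tile buffers.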
@@ -1865,9 +1864,12 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
   float pipe_size = (a_size + b_size) * pipe_stages;

+  float reduce_size = max(th_config.num_threads * 32 * 4,
+                          (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
+
   TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity

-  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
+  return pipe_size + reduce_size < 0.95f * (max_shared_mem - scales_cache_size);
 }

 bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
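A quick sanity check on the new reduce_size term, under an assumed thread config (not taken from this commit) of num_threads = 256, tb_n = 256, tb_max_m = 64: both operands of max() come out to 32768 bytes, so roughly 32 KiB of the shared-memory budget is now reserved for reduction scratch rather than silently borrowed from the pipeline buffers.

    #include <algorithm>
    #include <cstdio>

    int main() {
      // Assumed config, for illustration only.
      int num_threads = 256, tb_n = 256, tb_max_m = 64;
      int reduce_size =
          std::max(num_threads * 32 * 4,
                   (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
      printf("reduce_size = %d bytes\n", reduce_size);  // prints 32768
    }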