Resolve race conditions in Marlin kernel (#11493)
Signed-off-by: wchen61 <wchen61@foxmail.com>
This commit is contained in:
parent
187e32997c
commit
5dba257506
@ -834,6 +834,7 @@ __global__ void Marlin(
|
|||||||
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
|
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
|
||||||
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
|
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
|
||||||
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
||||||
|
int4* sh_red = sh_s + (stages * s_sh_stage);
|
||||||
|
|
||||||
// Register storage for double buffer of shared memory reads.
|
// Register storage for double buffer of shared memory reads.
|
||||||
FragA frag_a[2][thread_m_blocks];
|
FragA frag_a[2][thread_m_blocks];
|
||||||
@ -932,11 +933,11 @@ __global__ void Marlin(
|
|||||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||||
|
|
||||||
if constexpr (group_blocks >= thread_k_blocks) {
|
if constexpr (group_blocks >= thread_k_blocks) {
|
||||||
// Only fetch scales if this tile starts a new group
|
|
||||||
if (pipe % (group_blocks / thread_k_blocks) == 0) {
|
|
||||||
if (s_sh_wr_pred) {
|
if (s_sh_wr_pred) {
|
||||||
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
|
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
|
||||||
}
|
}
|
||||||
|
// Only fetch scales if this tile starts a new group
|
||||||
|
if ((pipe + 1) % (group_blocks / thread_k_blocks) == 0) {
|
||||||
s_gl_rd += s_gl_rd_delta;
|
s_gl_rd += s_gl_rd_delta;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -1038,9 +1039,7 @@ __global__ void Marlin(
|
|||||||
// No act-order case
|
// No act-order case
|
||||||
if constexpr (group_blocks != -1) {
|
if constexpr (group_blocks != -1) {
|
||||||
if constexpr (group_blocks >= thread_k_blocks) {
|
if constexpr (group_blocks >= thread_k_blocks) {
|
||||||
int4* sh_s_stage =
|
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||||
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
|
|
||||||
(pipe / (group_blocks / thread_k_blocks)));
|
|
||||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
|
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
|
||||||
} else {
|
} else {
|
||||||
int warp_id = threadIdx.x / 32;
|
int warp_id = threadIdx.x / 32;
|
||||||
@ -1339,15 +1338,15 @@ __global__ void Marlin(
|
|||||||
int red_sh_wr =
|
int red_sh_wr =
|
||||||
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
|
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
|
||||||
if (i < red_off) {
|
if (i < red_off) {
|
||||||
float* c_rd =
|
float* c_rd = reinterpret_cast<float*>(
|
||||||
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
|
&sh_red[red_sh_delta * j + red_sh_rd]);
|
||||||
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
|
float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k = 0; k < 4; k++)
|
for (int k = 0; k < 4; k++)
|
||||||
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
|
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
|
||||||
c_rd[k] + c_wr[k];
|
c_rd[k] + c_wr[k];
|
||||||
}
|
}
|
||||||
sh[red_sh_wr] =
|
sh_red[red_sh_wr] =
|
||||||
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
|
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1357,7 +1356,7 @@ __global__ void Marlin(
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4 * 2; i++) {
|
for (int i = 0; i < 4 * 2; i++) {
|
||||||
float* c_rd =
|
float* c_rd =
|
||||||
reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
|
reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < 4; j++)
|
for (int j = 0; j < 4; j++)
|
||||||
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
|
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
|
||||||
@ -1397,7 +1396,7 @@ __global__ void Marlin(
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
||||||
cp_async4_pred(
|
cp_async4_pred(
|
||||||
&sh[c_sh_wr + c_sh_wr_delta * i],
|
&sh_red[c_sh_wr + c_sh_wr_delta * i],
|
||||||
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
|
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
|
||||||
c_gl_wr_delta_i * (i % 2)],
|
c_gl_wr_delta_i * (i % 2)],
|
||||||
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
|
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
|
||||||
@ -1410,7 +1409,7 @@ __global__ void Marlin(
|
|||||||
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
||||||
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
|
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
|
||||||
if (!first) {
|
if (!first) {
|
||||||
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
|
int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < 2 * 4; j++) {
|
for (int j = 0; j < 2 * 4; j++) {
|
||||||
reinterpret_cast<float*>(
|
reinterpret_cast<float*>(
|
||||||
@ -1461,10 +1460,10 @@ __global__ void Marlin(
|
|||||||
float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
|
float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k = 0; k < th_size; k++) {
|
for (int k = 0; k < th_size; k++) {
|
||||||
sh[threadIdx.x] =
|
sh_red[threadIdx.x] =
|
||||||
C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
|
C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
|
||||||
|
|
||||||
float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
|
float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int f = 0; f < 4; f++) {
|
for (int f = 0; f < 4; f++) {
|
||||||
frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
|
frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
|
||||||
@ -1515,7 +1514,7 @@ __global__ void Marlin(
|
|||||||
res = __hmul2(res, s[0]);
|
res = __hmul2(res, s[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
((scalar_t2*)sh)[idx] = res;
|
((scalar_t2*)sh_red)[idx] = res;
|
||||||
};
|
};
|
||||||
|
|
||||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||||
@ -1543,7 +1542,7 @@ __global__ void Marlin(
|
|||||||
i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
|
i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
|
||||||
i++) {
|
i++) {
|
||||||
if (c_gl_wr < c_gl_wr_end) {
|
if (c_gl_wr < c_gl_wr_end) {
|
||||||
C[c_gl_wr] = sh[c_sh_rd];
|
C[c_gl_wr] = sh_red[c_sh_rd];
|
||||||
c_gl_wr += c_gl_wr_delta;
|
c_gl_wr += c_gl_wr_delta;
|
||||||
c_sh_rd += c_sh_rd_delta;
|
c_sh_rd += c_sh_rd_delta;
|
||||||
}
|
}
|
||||||
@ -1865,9 +1864,12 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
|
|||||||
|
|
||||||
float pipe_size = (a_size + b_size) * pipe_stages;
|
float pipe_size = (a_size + b_size) * pipe_stages;
|
||||||
|
|
||||||
|
float reduce_size = max(th_config.num_threads * 32 * 4,
|
||||||
|
(tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
|
||||||
|
|
||||||
TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
|
TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
|
||||||
|
|
||||||
return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
|
return pipe_size + reduce_size < 0.95f * (max_shared_mem - scales_cache_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
|
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user