Fix CUDA kernel index data type in vllm/csrc/quantization/fused_kernels/layernorm_utils.cuh +10 (#15159)
Signed-off-by: Lu Fang <lufang@fb.com>
Co-authored-by: Richard Barnes <rbarnes@meta.com>
This commit is contained in:
parent 0cfe7d386d
commit d3ccbd6350
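The diff below applies one pattern across the listed kernels: indices derived from threadIdx/blockIdx/blockDim (which are unsigned built-ins) were declared as signed int/int32_t and are switched to auto, so the variable keeps the compiler-deduced unsigned type and the implicit unsigned-to-signed conversion disappears. A minimal sketch of that pattern, assuming a standalone toy kernel (scale_inplace is illustrative and not part of this patch):

// Hypothetical standalone kernel, not taken from vLLM; it only illustrates
// the index-type pattern that this commit applies.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void scale_inplace(float* data, unsigned int n) {
  // threadIdx.x, blockIdx.x and blockDim.x are unsigned int, so `auto`
  // deduces unsigned int here; declaring the index as `int` would force an
  // implicit unsigned-to-signed conversion.
  const auto i = blockDim.x * blockIdx.x + threadIdx.x;
  if (i < n) {
    data[i] *= 2.0f;
  }
}

int main() {
  const unsigned int n = 1u << 20;
  float* d = nullptr;
  cudaMalloc(&d, n * sizeof(float));
  cudaMemset(d, 0, n * sizeof(float));
  scale_inplace<<<(n + 255) / 256, 256>>>(d, n);
  cudaDeviceSynchronize();
  cudaFree(d);
  printf("done\n");
  return 0;
}

The loop form in the diff follows the same idea: `for (auto i = threadIdx.x; ...)` keeps the induction variable the same type as threadIdx.x instead of forcing a signed 32-bit index.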
@@ -24,7 +24,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
   // sum of squares
   float ss = 0.0f;

-  for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+  for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
     float x = static_cast<float>(input[token_offset + i]);
     if constexpr (has_residual) {
       x += static_cast<float>(residual[token_offset + i]);
@@ -58,7 +58,7 @@ __device__ void compute_dynamic_per_token_scales(
   constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};

   float block_absmax_val_maybe = 0.0f;
-  for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+  for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
     float x = static_cast<float>(input[token_offset + i]);
     if constexpr (has_residual) {
       x += static_cast<float>(residual[token_offset + i]);
@@ -103,7 +103,7 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
   int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
   ;

-  for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+  for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
     float x = static_cast<float>(input[token_offset + i]);
     if constexpr (has_residual) {
       x += static_cast<float>(residual[token_offset + i]);
@@ -142,7 +142,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
   int32_t const num_vec_elems = hidden_size >> 2;

 #pragma unroll 4
-  for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
+  for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
     vec4_t<scalar_t> in = vec_input[i];

     vec4_t<float> x;
@@ -206,7 +206,7 @@ __device__ void compute_dynamic_per_token_scales(
   float block_absmax_val_maybe = 0.0f;

 #pragma unroll 4
-  for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
+  for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
     vec4_t<scalar_t> in = vec_input[i];
     vec4_t<scalar_t> const w = vec_weight[i];

@@ -286,7 +286,7 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
   // TODO(luka/varun) extract into type-agnostic vectorized quant function to
   // replace scaled_fp8_conversion_vec
 #pragma unroll 4
-  for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
+  for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
     vec4_t<scalar_t> const in = vec_input[i];
     vec4_t<scalar_t> const w = vec_weight[i];

@@ -101,10 +101,10 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
 template<typename dst_t>
 static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {

-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
     const block_q2_K * x = (const block_q2_K *) vx;

-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int n = tid/32;
     const int l = tid - 32*n;
     const int is = 8*n + l/16;
@@ -123,10 +123,10 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
 template<typename dst_t>
 static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {

-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
     const block_q3_K * x = (const block_q3_K *) vx;

-    const int r = threadIdx.x/4;
+    const auto r = threadIdx.x/4;
     const int tid = r/2;
     const int is0 = r%2;
     const int l0 = 16*is0 + 4*(threadIdx.x%4);
@@ -164,10 +164,10 @@ template<typename dst_t>
 static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q4_K * x = (const block_q4_K *) vx;

-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;

     // assume 32 threads
-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int il = tid/8;
     const int ir = tid%8;
     const int is = 2*il;
@@ -197,10 +197,10 @@ template<typename dst_t>
 static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q5_K * x = (const block_q5_K *) vx;

-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;

     // assume 64 threads - this is very slightly better than the one below
-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int il = tid/16; // il is in 0...3
     const int ir = tid%16; // ir is in 0...15
     const int is = 2*il;   // is is in 0...6
@@ -231,10 +231,10 @@ template<typename dst_t>
 static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q6_K * x = (const block_q6_K *) vx;

-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;

     // assume 64 threads - this is very slightly better than the one below
-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int ip = tid/32;      // ip is 0 or 1
     const int il = tid - 32*ip; // 0...32
     const int is = 8*ip + il/16;
@@ -256,10 +256,10 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
 template<typename dst_t>
 static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
     const block_iq2_xxs * x = (const block_iq2_xxs *) vx;

-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -275,10 +275,10 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
 template<typename dst_t>
 static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
     const block_iq2_xs * x = (const block_iq2_xs *) vx;

-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -293,10 +293,10 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
 template<typename dst_t>
 static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {

-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
     const block_iq2_s * x = (const block_iq2_s *) vx;

-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -309,10 +309,10 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_
 template<typename dst_t>
 static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
     const block_iq3_xxs * x = (const block_iq3_xxs *) vx;

-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -332,10 +332,10 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
 template<typename dst_t>
 static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {

-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
     const block_iq3_s * x = (const block_iq3_s *) vx;

-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -399,10 +399,10 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
 template<typename dst_t>
 static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {

-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
     const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);

-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 4*il;
@@ -417,10 +417,10 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst

 template<typename dst_t>
 static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
     const block_iq4_xs * x = (const block_iq4_xs *)vx;

-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 4*il;
@@ -565,4 +565,4 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
     default:
       return nullptr;
   }
-}
+}

@@ -19,11 +19,11 @@ template <typename scalar_t>
 static __global__ void quantize_q8_1(const scalar_t* __restrict__ x,
                                      void* __restrict__ vy, const int kx,
                                      const int kx_padded) {
-  const int ix = blockDim.x * blockIdx.x + threadIdx.x;
+  const auto ix = blockDim.x * blockIdx.x + threadIdx.x;
   if (ix >= kx_padded) {
     return;
   }
-  const int iy = blockDim.y * blockIdx.y + threadIdx.y;
+  const auto iy = blockDim.y * blockIdx.y + threadIdx.y;
   const int i_padded = iy * kx_padded + ix;

   block_q8_1* y = (block_q8_1*)vy;

@@ -14,10 +14,10 @@ static __device__ __forceinline__ void mul_mat_q(

    const int & ncols_dst = ncols_y;

-   const int row_dst_0 = blockIdx.x*mmq_y;
+   const auto row_dst_0 = blockIdx.x*mmq_y;
    const int & row_x_0 = row_dst_0;

-   const int col_dst_0 = blockIdx.y*mmq_x;
+   const auto col_dst_0 = blockIdx.y*mmq_x;
    const int & col_y_0 = col_dst_0;

    int * tile_x_ql = nullptr;
@@ -39,7 +39,7 @@ static __device__ __forceinline__ void mul_mat_q(

 #pragma unroll
        for (int ir = 0; ir < qr && ib0 + ir * blocks_per_warp/qr < blocks_per_row_x; ++ir) {
-           const int kqs = ir*WARP_SIZE_GGUF + threadIdx.x;
+           const auto kqs = ir*WARP_SIZE_GGUF + threadIdx.x;
            const int kbxd = kqs / QI8_1;

 #pragma unroll
@@ -53,7 +53,7 @@ static __device__ __forceinline__ void mul_mat_q(
 #pragma unroll
        for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
            const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE_GGUF/QI8_1)) % mmq_x;
-           const int kby = threadIdx.x % (WARP_SIZE_GGUF/QI8_1);
+           const auto kby = threadIdx.x % (WARP_SIZE_GGUF/QI8_1);
            const int col_y_eff = min(col_y_0 + ids, ncols_y-1);

            // if the sum is not needed it's faster to transform the scale to f32 ahead of time
@@ -87,14 +87,14 @@ static __device__ __forceinline__ void mul_mat_q(

 #pragma unroll
    for (int j = 0; j < mmq_x; j += nwarps) {
-       const int col_dst = col_dst_0 + j + threadIdx.y;
+       const auto col_dst = col_dst_0 + j + threadIdx.y;
        if (col_dst >= ncols_dst) {
            return;
        }

 #pragma unroll
        for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
-           const int row_dst = row_dst_0 + threadIdx.x + i;
+           const auto row_dst = row_dst_0 + threadIdx.x + i;
            if (row_dst >= nrows_dst) {
                continue;
            }

@@ -1,7 +1,7 @@
 // copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
 template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
 static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst, const int ncols, const int nrows) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const auto row = blockIdx.x*blockDim.y + threadIdx.y;

    if (row >= nrows) {
        return;
@@ -16,7 +16,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
    const block_q_t * x = (const block_q_t *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

-    for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
+    for (auto i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
        const int ibx = row*blocks_per_row + i; // x block index

        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx

@@ -19,10 +19,10 @@ static __device__ __forceinline__ void moe_q(

   const int ncols_dst = ncols_y * top_k;

-  const int row_dst_0 = blockIdx.x * mmq_y;
+  const auto row_dst_0 = blockIdx.x * mmq_y;
   const int& row_x_0 = row_dst_0;

-  const int col_dst_0 = blockIdx.y * mmq_x;
+  const auto col_dst_0 = blockIdx.y * mmq_x;

   int token_offs[mmq_x / nwarps];
   for (int i = 0; i < mmq_x; i += nwarps) {
@@ -56,7 +56,7 @@ static __device__ __forceinline__ void moe_q(
   const int n_per_r = ((qk * blocks_per_warp) / qr);
 #pragma unroll
   for (int ir = 0; ir < qr && ib0 * qk + ir * n_per_r < ncols_x; ++ir) {
-    const int kqs = ir * WARP_SIZE_GGUF + threadIdx.x;
+    const auto kqs = ir * WARP_SIZE_GGUF + threadIdx.x;
     const int kbxd = kqs / QI8_1;

 #pragma unroll
@@ -73,7 +73,7 @@ static __device__ __forceinline__ void moe_q(
     }

     if (threadIdx.x < n_per_r / QK8_1) {
-      const int kby = threadIdx.x % (WARP_SIZE_GGUF / QI8_1);
+      const auto kby = threadIdx.x % (WARP_SIZE_GGUF / QI8_1);
       const int col_y_eff = token_offs[threadIdx.y] / top_k;
       const int block_x =
           ib0 * (qk / QK8_1) + ir * (WARP_SIZE_GGUF / QI8_1) + kby;
@@ -119,7 +119,7 @@ static __device__ __forceinline__ void moe_q(

 #pragma unroll
   for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
-    const int row_dst = row_dst_0 + threadIdx.x + i;
+    const auto row_dst = row_dst_0 + threadIdx.x + i;
     if (row_dst >= nrows_dst) {
       continue;
     }

@@ -199,12 +199,12 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
   MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
   MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

-  int t = threadIdx.x;
+  auto t = threadIdx.x;

   // Block
-  int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
-  int offset_m = blockIdx.y * m_count;
-  int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+  auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+  auto offset_m = blockIdx.y * m_count;
+  auto offset_k = blockIdx.z * BLOCK_KN_SIZE;

   [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
   [[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
@@ -337,12 +337,12 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
   MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
   MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

-  int t = threadIdx.x;
+  auto t = threadIdx.x;

   // Block
-  int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
-  int offset_m = blockIdx.y * m_count;
-  int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+  auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+  auto offset_m = blockIdx.y * m_count;
+  auto offset_k = blockIdx.z * BLOCK_KN_SIZE;

   [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
   [[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
@@ -458,12 +458,12 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
   MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
   MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

-  int t = threadIdx.x;
+  auto t = threadIdx.x;

   // Block
-  int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
-  int offset_m = blockIdx.y * m_count;
-  int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+  auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+  auto offset_m = blockIdx.y * m_count;
+  auto offset_k = blockIdx.z * BLOCK_KN_SIZE;

   [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
   [[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
@@ -586,12 +586,12 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
   MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
   MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

-  int t = threadIdx.x;
+  auto t = threadIdx.x;

   // Block
-  int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
-  int offset_m = blockIdx.y * m_count;
-  int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+  auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+  auto offset_m = blockIdx.y * m_count;
+  auto offset_k = blockIdx.z * BLOCK_KN_SIZE;

   [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
   [[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
@@ -765,14 +765,14 @@ __global__ void reconstruct_exllama_8bit_kernel(
   MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
   MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

-  int offset_k = BLOCK_KN_SIZE * blockIdx.y;
-  int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
+  auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
+  auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;

   int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

   // Preload remapping table
   __shared__ int perm[BLOCK_KN_SIZE];
-  int t = threadIdx.x;
+  auto t = threadIdx.x;

   if (b_q_perm) {
     if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
@@ -862,14 +862,14 @@ __global__ void reconstruct_exllama_4bit_kernel(
   MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
   MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

-  int offset_k = BLOCK_KN_SIZE * blockIdx.y;
-  int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
+  auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
+  auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;

   int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

   // Preload remapping table
   __shared__ int perm[BLOCK_KN_SIZE];
-  int t = threadIdx.x;
+  auto t = threadIdx.x;

   if (b_q_perm) {
     if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
@@ -967,14 +967,14 @@ __global__ void reconstruct_exllama_3bit_kernel(
   MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
   MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

-  int offset_k = BLOCK_KN_SIZE * blockIdx.y;
-  int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
+  auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
+  auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;

   int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

   // Preload remapping table
   __shared__ int perm[BLOCK_KN_SIZE];
-  int t = threadIdx.x;
+  auto t = threadIdx.x;

   if (b_q_perm) {
     if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
@@ -1065,14 +1065,14 @@ __global__ void reconstruct_exllama_2bit_kernel(
   MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
   MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

-  int offset_k = BLOCK_KN_SIZE * blockIdx.y;
-  int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
+  auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
+  auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;

   int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

   // Preload remapping table
   __shared__ int perm[BLOCK_KN_SIZE];
-  int t = threadIdx.x;
+  auto t = threadIdx.x;

   if (b_q_perm) {
     if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
@@ -1181,11 +1181,11 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
   int zero_width = width / 8;
   int vec_height = height * 4;
   const int blockwidth2 = BLOCK_KN_SIZE / 2;
-  int b = blockIdx.y * BLOCK_M_SIZE_MAX;
+  auto b = blockIdx.y * BLOCK_M_SIZE_MAX;
   int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
-  int h = BLOCK_KN_SIZE * blockIdx.z / 8;
+  auto h = BLOCK_KN_SIZE * blockIdx.z / 8;
   int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
-  int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
+  auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;

   __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
   if (threadIdx.x < h_end) {
@@ -1197,8 +1197,8 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
   }

   __shared__ half2 deq2[256][8];
-  int val = threadIdx.x / 8;
-  int off = threadIdx.x % 8;
+  auto val = threadIdx.x / 8;
+  auto off = threadIdx.x % 8;
   for (; val < 256; val += BLOCK_KN_SIZE / 8) {
     deq2[val][off] =
         __halves2half2(__int2half_rn(val & 0xF), __int2half_rn(val >> 4));
@@ -1280,11 +1280,11 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
   int zero_width = width / 4;
   int vec_height = height * 2;
   const int blockwidth2 = BLOCK_KN_SIZE / 2;
-  int b = blockIdx.y * BLOCK_M_SIZE_MAX;
+  auto b = blockIdx.y * BLOCK_M_SIZE_MAX;
   int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
-  int h = BLOCK_KN_SIZE * blockIdx.z / 4;
+  auto h = BLOCK_KN_SIZE * blockIdx.z / 4;
   int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2;
-  int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
+  auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;

   __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
   if (threadIdx.x < h_end) {
@@ -1393,8 +1393,8 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w,
                                         half* __restrict__ out) {
   // Start of block

-  int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
-  int row = blockIdx.y * 32 / bit;
+  auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
+  auto row = blockIdx.y * 32 / bit;
   if (column >= width) return;

   // Views
@@ -1425,8 +1425,8 @@ __global__ void reconstruct_gptq_3bit_kernel(
     const int height, const int width, const int group,
     half* __restrict__ out) {
   // Start of block
-  int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
-  int row = blockIdx.y * 32;
+  auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
+  auto row = blockIdx.y * 32;
   if (column >= width) return;

   // Views
@@ -1542,7 +1542,7 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,

 __global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight,
                                     const int size_k, const int size_n) {
-  int n = blockIdx.x * THREADS_X + threadIdx.x;
+  auto n = blockIdx.x * THREADS_X + threadIdx.x;
   if (n >= size_n) return;
   int k = 0;
   uint32_t* b_ptr = b_q_weight + n;
@@ -1555,7 +1555,7 @@ __global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight,

 __global__ void shuffle_8bit_kernel(uint32_t* __restrict__ b_q_weight,
                                     const int size_k, const int size_n) {
-  int n = blockIdx.x * THREADS_X + threadIdx.x;
+  auto n = blockIdx.x * THREADS_X + threadIdx.x;
   if (n >= size_n) return;
   int k = 0;
   uint32_t* b_ptr = b_q_weight + n;
@@ -1568,7 +1568,7 @@ __global__ void shuffle_8bit_kernel(uint32_t* __restrict__ b_q_weight,

 __global__ void shuffle_2bit_kernel(uint32_t* __restrict__ b_q_weight,
                                     const int size_k, const int size_n) {
-  int n = blockIdx.x * THREADS_X + threadIdx.x;
+  auto n = blockIdx.x * THREADS_X + threadIdx.x;
   if (n >= size_n) return;
   int k = 0;
   uint32_t* b_ptr = b_q_weight + n;
@@ -1581,7 +1581,7 @@ __global__ void shuffle_2bit_kernel(uint32_t* __restrict__ b_q_weight,

 __global__ void shuffle_3bit_kernel(uint32_t* __restrict__ b_q_weight,
                                     const int size_k, const int size_n) {
-  int n = blockIdx.x * THREADS_X + threadIdx.x;
+  auto n = blockIdx.x * THREADS_X + threadIdx.x;
   if (n >= size_n) return;
   int k = 0;
   uint32_t* b_ptr = b_q_weight + n;
@@ -1599,9 +1599,9 @@ __global__ void make_sequential_4bit_kernel(const uint32_t* __restrict__ w,
   const uint64_t* w2 = (uint64_t*)w;
   uint64_t* w_new2 = (uint64_t*)w_new;
   int w2_stride = w_width >> 1;
-  int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
+  auto w2_column = THREADS_X * blockIdx.x + threadIdx.x;
   if (w2_column >= w2_stride) return;
-  int w_new2_row = blockIdx.y;
+  auto w_new2_row = blockIdx.y;
   int q_perm_idx = w_new2_row << 3;
   uint64_t dst = 0;

@@ -1630,9 +1630,9 @@ __global__ void make_sequential_2bit_kernel(const uint32_t* __restrict__ w,
   const uint64_t* w2 = (uint64_t*)w;
   uint64_t* w_new2 = (uint64_t*)w_new;
   int w2_stride = w_width >> 1;
-  int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
+  auto w2_column = THREADS_X * blockIdx.x + threadIdx.x;
   if (w2_column >= w2_stride) return;
-  int w_new2_row = blockIdx.y;
+  auto w_new2_row = blockIdx.y;
   int q_perm_idx = w_new2_row << 4;
   uint64_t dst = 0;

@@ -1658,10 +1658,10 @@ __global__ void make_sequential_3bit_kernel(const uint32_t* __restrict__ w,
                                             uint32_t* __restrict__ w_new,
                                             const int* __restrict__ q_perm,
                                             const int w_width) {
-  int w_column = THREADS_X * blockIdx.x + threadIdx.x;
+  auto w_column = THREADS_X * blockIdx.x + threadIdx.x;
   if (w_column >= w_width) return;
-  int w_new_row = blockIdx.y * 3;
-  int q_perm_idx = blockIdx.y << 5;
+  auto w_new_row = blockIdx.y * 3;
+  auto q_perm_idx = blockIdx.y << 5;
   uint32_t dst[3] = {0, 0, 0};

 #pragma unroll
@@ -1744,9 +1744,9 @@ __global__ void make_sequential_8bit_kernel(const uint32_t* __restrict__ w,
   const uint64_t* w2 = (uint64_t*)w;
   uint64_t* w_new2 = (uint64_t*)w_new;
   int w2_stride = w_width >> 1;
-  int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
+  auto w2_column = THREADS_X * blockIdx.x + threadIdx.x;
   if (w2_column >= w2_stride) return;
-  int w_new2_row = blockIdx.y;
+  auto w_new2_row = blockIdx.y;
   int q_perm_idx = w_new2_row << 2;
   uint64_t dst = 0;

@@ -55,11 +55,11 @@ struct GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
     this_block_B_base_ptr = params.B_ptr + blockIdx.y * Ntile * params.K +
                             blockIdx.z * params.SplitK * 4;

-    const int lane_id = threadIdx.x % WARP_SIZE;
+    const auto lane_id = threadIdx.x % WARP_SIZE;

     // For matrix A, a block load/store Mtile(row) x 32(col) elements in
     // multiple iters, 8x4 warp load/store 8(row) x 32(col) elements per iter
-    const int Aldg_row_base_idx = threadIdx.x / 4;
+    const auto Aldg_row_base_idx = threadIdx.x / 4;
     Aldg_col_idx = (threadIdx.x % 4) * LDG_ELEMENT_CNT_A;
     const int Aldg_base_offset = Aldg_row_base_idx * params.K + Aldg_col_idx;

@@ -67,7 +67,7 @@ struct GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
     // elements of N32K16 packing in multiple iters, 4x8 warp load/store 4(row)
     // * 128(col) per iter
     Bldg_col_idx = (threadIdx.x % 8) * LDG_ELEMENT_CNT_B;
-    const int Bldg_row_base_idx = threadIdx.x / 8;
+    const auto Bldg_row_base_idx = threadIdx.x / 8;
     const int Bldg_base_offset =
         Bldg_row_base_idx * params.K * 4 + Bldg_col_idx;

@@ -89,7 +89,7 @@ struct GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
     B_ldg_guard = 0;
 #pragma unroll
     for (int i = 0; i < (Mtile + M_SIZE_ONE_LOAD - 1) / M_SIZE_ONE_LOAD; ++i) {
-      int m_idx = blockIdx.x * Mtile + Aldg_row_base_idx + i * M_SIZE_ONE_LOAD;
+      auto m_idx = blockIdx.x * Mtile + Aldg_row_base_idx + i * M_SIZE_ONE_LOAD;
       if (m_idx < params.M) {
         A_ldg_guard |= (1u << i);
       }
@@ -98,8 +98,8 @@ struct GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
     const int N_padded = (params.N + 31) / 32 * 32;
 #pragma unroll
     for (int i = 0; i < (Ntile + N_SIZE_ONE_LOAD - 1) / N_SIZE_ONE_LOAD; ++i) {
-      int n_idx = blockIdx.y * Ntile + (Bldg_row_base_idx / 8) * 32 +
-                  i * N_SIZE_ONE_LOAD;
+      auto n_idx = blockIdx.y * Ntile + (Bldg_row_base_idx / 8) * 32 +
+                   i * N_SIZE_ONE_LOAD;
       if (n_idx < N_padded) {
         B_ldg_guard |= (1u << i);
       }
@@ -355,7 +355,7 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
   __device__ void fused_splitk_reduce() {
     // need splitk-reduce if enable splitk
     if (gridDim.z > 1) {
-      int blk_red_idx = blockIdx.x * gridDim.y + blockIdx.y;
+      auto blk_red_idx = blockIdx.x * gridDim.y + blockIdx.y;
       // Wait for all previous blocks in the splitk direction to accumulate the
       // results into C_tmp
       if (threadIdx.x == 0) {
@@ -371,7 +371,7 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
       }
       __syncthreads();

-      int C_tmp_base_offset = blk_red_idx * Mtile * Ntile + threadIdx.x * 4;
+      auto C_tmp_base_offset = blk_red_idx * Mtile * Ntile + threadIdx.x * 4;
       if (blockIdx.z != 0) {
         // expecting that temporary register here reuses the previous A&B frag
         // register
@@ -456,7 +456,7 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {

     FType* C_base_ptr = this_block_C_base_ptr + store_c_base_offset;
     // C_tile lds and stg
-    int m_base_idx = store_c_row_base_idx + blockIdx.x * Mtile;
+    auto m_base_idx = store_c_row_base_idx + blockIdx.x * Mtile;
     bool n_guard = (store_c_col_idx + blockIdx.y * Ntile) < params.N;
     if (WARP_NTILE == 32) {
       int lds_c_base_offset = warp_id * Mtile * WARP_NTILE +
@@ -580,9 +580,9 @@ __global__ void __launch_bounds__(BLOCK)
   int sts_stage_idx = 0;
   int lds_stage_idx = 0;

-  int tb_k_slice = blockIdx.z * params.SplitK + params.SplitK <= params.K
-                       ? params.SplitK
-                       : params.K - blockIdx.z * params.SplitK;
+  auto tb_k_slice = blockIdx.z * params.SplitK + params.SplitK <= params.K
+                        ? params.SplitK
+                        : params.K - blockIdx.z * params.SplitK;
   int k_tiles = (tb_k_slice + 31) / 32;
   int first_k_tile = tb_k_slice - (k_tiles - 1) * 32;

@@ -777,13 +777,13 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
     const QT* qdata, const FT* scales, const FT* zeros, FT* fdata,
     const int N_32align, const int N, const int K) {
   __shared__ FT smem[64 * 32];
-  int warp_id = threadIdx.x / 32;
-  int lane_id = threadIdx.x % 32;
-  const int src_row_idx = blockIdx.x * 8 + lane_id / 4;
+  auto warp_id = threadIdx.x / 32;
+  auto lane_id = threadIdx.x % 32;
+  const auto src_row_idx = blockIdx.x * 8 + lane_id / 4;
   const int src_col_idx =
       blockIdx.y * 64 * 4 + warp_id * 16 * 4 + (lane_id % 4) * 16;
   const int src_offset = src_row_idx * K * 4 + src_col_idx;
-  int params_nidx = blockIdx.x * 32 + (lane_id / 4) * 4;
+  auto params_nidx = blockIdx.x * 32 + (lane_id / 4) * 4;

   QT qval_reg[16];
   const QT* pdata = qdata + src_offset;
@@ -829,8 +829,8 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
         *reinterpret_cast<uint4*>(smem + lds_base_offset + i * 32 * 32);
   }

-  const int dst_row_base_kidx = blockIdx.y * 64 + threadIdx.x / 4;
-  const int dst_col_nidx = blockIdx.x * 32 + (threadIdx.x % 4) * 8;
+  const auto dst_row_base_kidx = blockIdx.y * 64 + threadIdx.x / 4;
+  const auto dst_col_nidx = blockIdx.x * 32 + (threadIdx.x % 4) * 8;
 #pragma unroll
   for (int i = 0; i < 2; ++i) {
     int dst_row_kidx = dst_row_base_kidx + i * 32;
@@ -1008,4 +1008,4 @@ torch::Tensor allspark_w8a16_gemm(

 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
   m.impl("allspark_w8a16_gemm", &allspark_w8a16_gemm);
-}
+}

@@ -13,8 +13,8 @@ __global__ void __launch_bounds__(128)
         const uint8_t* B, const FType* B_scale, const FType* B_zero,
         uint8_t* B_result, FType* B_scale_result, FType* B_zero_result,
         const int K, const int N, const int N_32align) {
-  const int lane_id = threadIdx.x % 32;
-  const int warp_id = threadIdx.x / 32;
+  const auto lane_id = threadIdx.x % 32;
+  const auto warp_id = threadIdx.x / 32;

   if (blockIdx.x != gridDim.x - 1) {
     // Load B
@@ -50,7 +50,7 @@ __global__ void __launch_bounds__(128)
     }

     // Store B
-    const int dst_row_base_idx = blockIdx.y * (128 / 4) + (lane_id / 8) * 8;
+    const auto dst_row_base_idx = blockIdx.y * (128 / 4) + (lane_id / 8) * 8;
     const int dst_col_idx =
         blockIdx.x * (64 * 4) + warp_id * 64 + (lane_id % 8) * 8;
     for (int i = 0; i < 8; ++i) {
@@ -65,7 +65,7 @@ __global__ void __launch_bounds__(128)
   } else {
     // Load B_scale and B_zero
     FType b_scale_reg, b_zero_reg;
-    int src_offset = blockIdx.y * 128 + threadIdx.x;
+    auto src_offset = blockIdx.y * 128 + threadIdx.x;
     ldg16_cg_0(b_scale_reg, B_scale + src_offset, src_offset < N);
     if (B_zero != nullptr)
       ldg16_cg_0(b_zero_reg, B_zero + src_offset, src_offset < N);

@@ -62,7 +62,7 @@ template <typename FType, int BLOCK, int N_MATRIX>
 __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
                                               uint32_t n, uint32_t n_matrix,
                                               uint32_t matrix_size) {
-  int idx = blockIdx.x * BLOCK + threadIdx.x;
+  auto idx = blockIdx.x * BLOCK + threadIdx.x;

   if (idx >= matrix_size) {
     return;
@@ -407,4 +407,4 @@ static __device__ half2 inline num2num2(const half x) {
   return __half2half2(x);
 }

-} // namespace allspark
+} // namespace allspark