[Kernel] allow non-contiguous input for marlin kernel (#14658)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Jinzhen Lin 2025-03-24 21:21:33 +08:00 committed by GitHub
parent 7ffcccfa5c
commit 6b3cc75be0
2 changed files with 73 additions and 20 deletions


@@ -42,7 +42,7 @@ namespace marlin {
 __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                     int const* __restrict__ perm_int_ptr,
                                     int4* __restrict__ out_int4_ptr, int size_m,
-                                    int size_k, int block_rows) {}
+                                    int size_k, int lda, int block_rows) {}
 
 template <typename scalar_t,          // compute dtype, half or nv_float16
           const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
@@ -459,7 +459,7 @@ __device__ inline void barrier_release(int* lock, bool reset = false) {
 __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                     int const* __restrict__ perm_int_ptr,
                                     int4* __restrict__ out_int4_ptr, int size_m,
-                                    int size_k, int block_rows) {
+                                    int size_k, int lda, int block_rows) {
   int start_row = block_rows * blockIdx.x;
   int finish_row = start_row + block_rows;
   if (finish_row > size_m) {
@@ -467,16 +467,19 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
   }
   int cur_block_rows = finish_row - start_row;
 
-  int row_stride = size_k * sizeof(half) / 16;
+  int input_row_stride = lda * sizeof(half) / 16;
+  int output_row_stride = size_k * sizeof(half) / 16;
 
   auto permute_row = [&](int row) {
     int iters = size_k / default_threads;
     int rest = size_k % default_threads;
 
-    int offset = row * row_stride;
+    int input_offset = row * input_row_stride;
+    int output_offset = row * output_row_stride;
 
-    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
-    half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
+    half const* a_row_half =
+        reinterpret_cast<half const*>(a_int4_ptr + input_offset);
+    half* out_half = reinterpret_cast<half*>(out_int4_ptr + output_offset);
 
     int base_k = 0;
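With separate input and output row strides, the kernel can now read each row of A at the caller's stride (lda, in halves) while writing the permuted row densely at size_k, so the scratch tensor it produces is always contiguous. A rough host-side emulation of what the kernel computes, in PyTorch (the helper name is illustrative, not part of the commit):

    import torch

    def emulate_permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
        # `a` may be a non-contiguous row-major view (a.stride(0) == lda >= size_k);
        # the result is densely packed, with each row's columns gathered by `perm`.
        size_m, size_k = a.shape
        out = torch.empty(size_m, size_k, dtype=a.dtype, device=a.device)
        for row in range(size_m):
            out[row] = a[row][perm]  # read strided input, write packed output
        return out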
@@ -537,6 +540,7 @@ __global__ void Marlin(
     int prob_m,           // batch dimension m
     int prob_n,           // output dimension n
     int prob_k,           // reduction dimension k
+    int lda,              // A.stride(0), equal to prob_k if A is contiguous
     int* locks,           // extra global storage for barrier synchronization
     bool use_atomic_add,  // whether to use atomic add to reduce
     bool use_fp32_reduce  // whether to use fp32 global reduce
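Here lda is the leading dimension of the row-major A in fp16 elements: element (i, j) lives at flat offset i * lda + j from the tensor's start. For a contiguous matrix lda == prob_k; for a row window of a wider tensor it is the parent's row width. A quick PyTorch illustration (shapes arbitrary):

    import torch

    a_big = torch.rand(64, 2048).half()
    a = a_big[8:40, 8:1032]    # 32 x 1024 view, not contiguous

    prob_m, prob_k = a.shape   # 32, 1024
    lda = a.stride(0)          # 2048: rows are still 2048 elements apart
    assert a.stride(1) == 1 and not a.is_contiguous()

    # element (i, j) of the view sits at storage_offset + i * lda + j
    i, j = 3, 17
    flat = a_big.flatten()
    assert flat[a.storage_offset() + i * lda + j] == a[i, j]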
@@ -600,7 +604,7 @@ __global__ void Marlin(
   // We can easily implement parallel problem execution by just remapping
   // indices and advancing global pointers
   if (slice_col_par >= n_tiles) {
-    A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
+    A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8;
     C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
     locks += (slice_col_par / n_tiles) * n_tiles;
     slice_col = slice_col_par % n_tiles;
@ -631,7 +635,7 @@ __global__ void Marlin(
}
}
if (slice_col == n_tiles) {
A += 16 * thread_m_blocks * prob_k / 8;
A += 16 * thread_m_blocks * lda / 8;
C += 16 * thread_m_blocks * prob_n / 8;
locks += n_tiles;
slice_col = 0;
@@ -643,7 +647,7 @@ __global__ void Marlin(
   // A sizes/strides
 
   // stride of the A matrix in global memory
-  int a_gl_stride = prob_k / 8;
+  int a_gl_stride = lda / 8;
   // stride of an A matrix tile in shared memory
   constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
   // delta between subsequent A tiles in global memory
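Each of these expressions scales by lda / 8 rather than prob_k / 8 because the kernel addresses A in int4 units: one 16-byte int4 holds 8 fp16 values, so stepping over one row of A means lda / 8 int4s, and stepping over a slice of 16 * thread_m_blocks rows means 16 * thread_m_blocks * lda / 8. A small numeric check (values illustrative):

    HALF_PER_INT4 = 8            # one 16-byte int4 holds 8 fp16 elements

    def a_rows_offset_int4(rows: int, lda: int) -> int:
        # offset, in int4 units, spanned by `rows` full rows of A
        return rows * lda // HALF_PER_INT4

    prob_k, lda = 1024, 2048     # strided view: rows are 2048 halves apart
    rows = 16 * 1                # 16 * thread_m_blocks with thread_m_blocks = 1

    assert a_rows_offset_int4(rows, lda) == 4096   # correct row boundary
    assert rows * prob_k // HALF_PER_INT4 == 2048  # old math: wrong row when lda != prob_k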
@ -1780,8 +1784,8 @@ __global__ void Marlin(
HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
num_groups, prob_m, prob_n, prob_k, locks, use_atomic_add, \
use_fp32_reduce); \
num_groups, prob_m, prob_n, prob_k, lda, locks, \
use_atomic_add, use_fp32_reduce); \
} \
}
@ -2071,7 +2075,7 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m,
int prob_n, int prob_k, void* workspace,
int prob_n, int prob_k, int lda, void* workspace,
vllm::ScalarType const& q_type, bool has_act_order,
bool is_k_full, bool has_zp, int num_groups, int group_size,
int dev, cudaStream_t stream, int thread_k, int thread_n,
@@ -2184,8 +2188,9 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
     // Permute A columns
     int block_rows = div_ceil(prob_m, blocks);
     permute_cols_kernel<<<blocks, default_threads, 0, stream>>>(
-        A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);
+        A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows);
     A_ptr = a_tmp_ptr;
+    lda = prob_k;
   }
 
   // If we have a full K, then we can run the non-act-order version of Marlin
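On the act-order path, permute_cols_kernel gathers the columns of A into the contiguous scratch tensor a_tmp, so from that point on the effective leading dimension is simply prob_k; hence the lda = prob_k reset. A hedged sketch of that host-side decision (Python-level pseudocode mirroring marlin_mm, not the actual host code):

    def prepare_a(a, perm, has_act_order):
        # Sketch only: marlin_mm does this with raw device pointers.
        if has_act_order:
            a = a[:, perm].contiguous()  # what permute_cols_kernel produces
            lda = a.stride(0)            # == size_k: scratch is densely packed
        else:
            lda = a.stride(0)            # may exceed size_k for a strided view
        return a, lda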
@@ -2244,7 +2249,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
                 ", num_bits = ", num_bits);
     }
 
-    A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
+    A_ptr += 16 * thread_m_blocks * (lda / 8) * par;
     C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
   }
 }
@@ -2300,7 +2305,10 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   // Verify device and strides
   TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
-  TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
+  TORCH_CHECK(a.stride(1) == 1, "A.stride(1) is not 1");
+  // We use int4 (16 bytes) to load A, so A must be aligned to 16 bytes
+  TORCH_CHECK(a.stride(0) % 8 == 0, "A.stride(0) must be divisible by 8");
+  TORCH_CHECK(((uint64_t)a.data_ptr()) % 16 == 0, "A must be aligned to 16 bytes");
 
   TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
   TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
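The blanket contiguity requirement is thus relaxed to three weaker checks: unit stride along k, a row stride that is a whole number of int4s (8 fp16 elements), and a 16-byte-aligned base pointer, since every global load of A is a 16-byte int4. A caller-side mirror of these checks (hypothetical helper, not part of the commit):

    import torch

    def marlin_accepts(a: torch.Tensor) -> bool:
        # Mirrors the relaxed TORCH_CHECKs for a half/bfloat16 A.
        return (a.stride(1) == 1             # rows densely packed along k
                and a.stride(0) % 8 == 0     # row stride = whole int4s
                and a.data_ptr() % 16 == 0)  # base address aligned to int4

    a = torch.rand(64, 2048).half()
    assert marlin_accepts(a[8:40, 8:1032])  # aligned window: accepted
    assert not marlin_accepts(a[:, 1:])     # base off by 2 bytes: rejected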
@@ -2432,7 +2440,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
         c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
         b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
-        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
+        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k, a.stride(0),
         workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
         thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
@@ -2443,10 +2451,10 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
         b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
         perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
-        workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
-        num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
-        use_fp32_reduce, is_zp_float);
+        a.stride(0), workspace.data_ptr(), b_q_type, has_act_order, is_k_full,
+        has_zp, num_groups, group_size, dev,
+        at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
+        marlin::max_par, use_atomic_add, use_fp32_reduce, is_zp_float);
   } else {
     TORCH_CHECK(false, "gptq_marlin_gemm only supports bfloat16 and float16");
   }


@@ -606,6 +606,51 @@ def test_marlin_qqq_gemm(
     assert max_diff < 0.04
 
 
+def test_marlin_gemm_subset_input():
+    quant_type = scalar_types.uint4b8
+    group_size = 128
+    size_m, size_k, size_n = 32, 1024, 2048
+
+    big_m = size_m * 2
+    big_k = size_k * 2
+    a_input = rand_data((big_m, big_k))[8:size_m + 8, 8:size_k + 8]
+    b_weight = rand_data((size_k, size_n))
+
+    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
+        b_weight, quant_type, group_size, False)
+    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
+
+    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
+                                GPTQ_MARLIN_MAX_PARALLEL)
+
+    output = ops.gptq_marlin_gemm(
+        a_input,
+        marlin_q_w,
+        marlin_s,
+        marlin_zp,
+        g_idx,
+        sort_indices,
+        workspace.scratch,
+        quant_type,
+        a_input.shape[0],
+        b_weight.shape[1],
+        a_input.shape[1],
+        is_k_full=True,
+        has_zp=False,
+        use_atomic_add=False,
+        use_fp32_reduce=True,
+        is_zp_float=False,
+    )
+    output_ref = torch.matmul(a_input, w_ref)
+
+    torch.cuda.synchronize()
+
+    max_diff = compute_max_diff(output, output_ref)
+
+    assert max_diff < 0.04
+
+
 def test_marlin_gemm_opcheck():
     size_m = 2048
     size_n = 4096
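The new test_marlin_gemm_subset_input exercises exactly this path: it carves a 32 x 1024 window out of a 64 x 2048 buffer starting at row 8, column 8, which satisfies the relaxed stride and alignment checks by construction. The arithmetic, spelled out (torch.rand stands in for the test's rand_data helper):

    import torch

    a_big = torch.rand(64, 2048).half()    # big_m x big_k backing buffer
    a_input = a_big[8:32 + 8, 8:1024 + 8]  # the test's subset view

    assert not a_input.is_contiguous()
    assert a_input.stride(1) == 1
    assert a_input.stride(0) == 2048       # big_k, divisible by 8
    # element offset 8 * 2048 + 8 = 16392 halves = 32784 bytes, a multiple of 16
    assert (a_input.storage_offset() * a_input.element_size()) % 16 == 0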