[Bugfix] Fine-tune gptq_marlin configs to be more similar to marlin (#4626)

This commit is contained in:
alexm-nm 2024-05-08 20:14:31 -04:00 committed by GitHub
parent 8b9241be3a
commit e288df0632
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -115,7 +115,8 @@ template <int lut> __device__ inline int lop3(int a, int b, int c) {
return res;
}
// Constructs destination register by taking bytes from 2 sources (based on mask)
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
@ -933,9 +934,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped partitioning
// minimizes the number of such reductions and our outputs are usually rather
// small, we perform this reduction serially in L2 cache.
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
@ -1275,13 +1276,22 @@ typedef struct {
thread_config_t tb_cfg;
} exec_config_t;
thread_config_t thread_configs[] = {
thread_config_t small_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{64, 256, 256}, // Default (max cache usage)
{64, 128, 128}, // Reduce N, reduce warps
{128, 64, 128}, // Reduce N more, but increase K
{128, 128, 256},
{64, 128, 128},
{128, 64, 128},
};
thread_config_t large_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{64, 256, 256},
{64, 128, 128},
{128, 64, 128},
};
@ -1397,11 +1407,21 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
int max_shared_mem) {
int max_m_blocks = 4;
while (max_m_blocks > 0) {
for (auto th_config : thread_configs) {
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full,
max_shared_mem)) {
return exec_config_t{max_m_blocks, th_config};
if (prob_m <= 16) {
for (auto th_config : small_batch_thread_configs) {
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full,
max_shared_mem)) {
return exec_config_t{max_m_blocks, th_config};
}
}
} else {
for (auto th_config : large_batch_thread_configs) {
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full,
max_shared_mem)) {
return exec_config_t{max_m_blocks, th_config};
}
}
}
@ -1574,10 +1594,12 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s,
}
CALL_IF(4, 32, 2, 256)
CALL_IF(4, 16, 4, 256)
CALL_IF(4, 8, 8, 256)
CALL_IF(4, 8, 4, 128)
CALL_IF(4, 4, 8, 128)
CALL_IF(8, 32, 2, 256)
CALL_IF(8, 16, 4, 256)
CALL_IF(8, 8, 8, 256)
CALL_IF(8, 8, 4, 128)
CALL_IF(8, 4, 8, 128)
else {