[Bugfix] Fine-tune gptq_marlin configs to be more similar to marlin (#4626)
commit e288df0632
parent 8b9241be3a
@@ -115,7 +115,8 @@ template <int lut> __device__ inline int lop3(int a, int b, int c) {
   return res;
 }
 
-// Constructs destination register by taking bytes from 2 sources (based on mask)
+// Constructs destination register by taking bytes from 2 sources (based on
+// mask)
 template <int start_byte, int mask>
 __device__ inline uint32_t prmt(uint32_t a) {
   uint32_t res;
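Side note on the comment being rewrapped above: the prmt helper presumably wraps the PTX byte-permute instruction (prmt.b32), where each destination byte is picked from the eight bytes of two 32-bit sources by a 4-bit nibble of a selector. Below is a minimal host-side C++ sketch of that selection behavior, basic mode only (sign-replication ignored); the function emulate_prmt and the sample values are illustrative and not part of this commit.

#include <cstdint>
#include <cstdio>

// Emulates the plain byte-select mode of PTX prmt.b32 on the host.
uint32_t emulate_prmt(uint32_t a, uint32_t b, uint32_t selector) {
  // The two sources form an 8-byte pool: bytes 0-3 from a, bytes 4-7 from b.
  uint8_t pool[8];
  for (int i = 0; i < 4; ++i) {
    pool[i] = (a >> (8 * i)) & 0xFF;
    pool[i + 4] = (b >> (8 * i)) & 0xFF;
  }
  // Nibble i of the selector picks the pool byte for destination byte i.
  uint32_t res = 0;
  for (int i = 0; i < 4; ++i) {
    uint32_t idx = (selector >> (4 * i)) & 0x7;
    res |= uint32_t(pool[idx]) << (8 * i);
  }
  return res;
}

int main() {
  // Build {b0, a2, a1, a0}: selector nibbles 4, 2, 1, 0 -> prints 0x44221100.
  printf("0x%08x\n", emulate_prmt(0x33221100u, 0x77665544u, 0x4210u));
  return 0;
}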
@@ -933,9 +934,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
   };
 
   // Since multiple threadblocks may process parts of the same column slice, we
-  // finally have to globally reduce over the results. As the striped partitioning
-  // minimizes the number of such reductions and our outputs are usually rather
-  // small, we perform this reduction serially in L2 cache.
+  // finally have to globally reduce over the results. As the striped
+  // partitioning minimizes the number of such reductions and our outputs are
+  // usually rather small, we perform this reduction serially in L2 cache.
   auto global_reduce = [&](bool first = false, bool last = false) {
     // We are very careful here to reduce directly in the output buffer to
     // maximize L2 cache utilization in this step. To do this, we write out
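For readers skimming the rewrapped comment: the idea is that every threadblock sharing a column slice accumulates straight into the output buffer, one block after another, so the reduction stays resident in L2. Below is a small host-side C++ sketch of that first-writes/later-adds pattern; the data and names are made up for illustration and are not taken from the kernel.

#include <cstdio>
#include <vector>

int main() {
  // Three "threadblocks" each hold a partial result for the same 4-wide slice.
  std::vector<std::vector<float>> partials = {
      {1, 2, 3, 4}, {10, 20, 30, 40}, {100, 200, 300, 400}};
  std::vector<float> output(4, 0.0f);

  for (size_t blk = 0; blk < partials.size(); ++blk) {
    bool first = (blk == 0);  // mirrors the lambda's `first` flag
    for (size_t i = 0; i < output.size(); ++i) {
      // First writer stores, later writers add: the reduction happens serially
      // in the output buffer itself.
      output[i] = first ? partials[blk][i] : output[i] + partials[blk][i];
    }
  }
  for (float v : output) printf("%.0f ", v);  // prints: 111 222 333 444
  printf("\n");
  return 0;
}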
@@ -1275,13 +1276,22 @@ typedef struct {
   thread_config_t tb_cfg;
 } exec_config_t;
 
-thread_config_t thread_configs[] = {
+thread_config_t small_batch_thread_configs[] = {
     // Ordered by priority
 
     // thread_k, thread_n, num_threads
-    {64, 256, 256},  // Default (max cache usage)
-    {64, 128, 128},  // Reduce N, reduce warps
-    {128, 64, 128},  // Reduce N more, but increase K
+    {128, 128, 256},
+    {64, 128, 128},
+    {128, 64, 128},
 };
 
+thread_config_t large_batch_thread_configs[] = {
+    // Ordered by priority
+
+    // thread_k, thread_n, num_threads
+    {64, 256, 256},
+    {64, 128, 128},
+    {128, 64, 128},
+
+};
 
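Reading the tables above: each row is a candidate launch shape, where thread_k and thread_n are the K and N extents of the sub-tile a threadblock covers and num_threads is the block size, with earlier rows preferred ("Ordered by priority"). The first small-batch row, {128, 128, 256}, is the new default that replaces {64, 256, 256} from the removed unified table. Below is a tiny C++ sketch of what one row implies; the /32 warp math is plain CUDA arithmetic, while the /16 block granularity is an assumption made only for illustration.

#include <cstdio>

struct thread_config_t {  // mirror of the struct in the diff, for a standalone build
  int thread_k;
  int thread_n;
  int num_threads;
};

int main() {
  // The new small-batch table, highest priority first.
  thread_config_t small_batch[] = {{128, 128, 256}, {64, 128, 128}, {128, 64, 128}};
  for (const thread_config_t& c : small_batch) {
    int warps = c.num_threads / 32;  // warps per threadblock
    int k_blocks = c.thread_k / 16;  // assumed 16-wide granularity along K
    int n_blocks = c.thread_n / 16;  // assumed 16-wide granularity along N
    printf("thread_k=%3d thread_n=%3d threads=%3d -> %d warps, %d x %d blocks\n",
           c.thread_k, c.thread_n, c.num_threads, warps, k_blocks, n_blocks);
  }
  return 0;
}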
@@ -1397,13 +1407,23 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
                                       int max_shared_mem) {
   int max_m_blocks = 4;
   while (max_m_blocks > 0) {
-    for (auto th_config : thread_configs) {
-      if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
-                          num_bits, group_size, has_act_order, is_k_full,
-                          max_shared_mem)) {
-        return exec_config_t{max_m_blocks, th_config};
-      }
-    }
+    if (prob_m <= 16) {
+      for (auto th_config : small_batch_thread_configs) {
+        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
+                            num_bits, group_size, has_act_order, is_k_full,
+                            max_shared_mem)) {
+          return exec_config_t{max_m_blocks, th_config};
+        }
+      }
+    } else {
+      for (auto th_config : large_batch_thread_configs) {
+        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
+                            num_bits, group_size, has_act_order, is_k_full,
+                            max_shared_mem)) {
+          return exec_config_t{max_m_blocks, th_config};
+        }
+      }
+    }
 
     printf("WARNING: Marlin kernel is reducing max_m_blocks due to small SM "
            "GPU cache. This may "
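The new control flow above splits the search by batch size: problems with prob_m <= 16 scan the small-batch table, everything else scans the large-batch table, and the first entry that passes is_valid_config wins. Below is a standalone C++ sketch of that first-fit selection; is_valid_config_stub and its fake shared-memory bound are invented for illustration and do not reproduce the kernel's real validity checks.

#include <cstdio>
#include <vector>

struct thread_config_t { int thread_k, thread_n, num_threads; };
struct exec_config_t { int max_m_blocks; thread_config_t tb_cfg; };

const std::vector<thread_config_t> small_batch_thread_configs = {
    {128, 128, 256}, {64, 128, 128}, {128, 64, 128}};
const std::vector<thread_config_t> large_batch_thread_configs = {
    {64, 256, 256}, {64, 128, 128}, {128, 64, 128}};

// Invented stand-in for is_valid_config: checks only divisibility and a rough
// byte budget so the first-fit behaviour can be run on the host.
bool is_valid_config_stub(const thread_config_t& c, int prob_n, int prob_k,
                          int max_shared_mem) {
  bool divisible = (prob_n % c.thread_n == 0) && (prob_k % c.thread_k == 0);
  int fake_smem_bytes = 2 * c.thread_k * c.thread_n;  // arbitrary fp16 proxy
  return divisible && fake_smem_bytes <= max_shared_mem;
}

exec_config_t determine_thread_config_sketch(int prob_m, int prob_n, int prob_k,
                                              int max_shared_mem) {
  const auto& table = (prob_m <= 16) ? small_batch_thread_configs
                                     : large_batch_thread_configs;
  for (const thread_config_t& c : table) {  // priority order: first valid wins
    if (is_valid_config_stub(c, prob_n, prob_k, max_shared_mem)) {
      return exec_config_t{4, c};  // 4 mirrors the initial max_m_blocks above
    }
  }
  return exec_config_t{0, {-1, -1, -1}};  // nothing fits
}

int main() {
  exec_config_t e = determine_thread_config_sketch(8, 4096, 4096, 96 * 1024);
  printf("picked thread_k=%d thread_n=%d num_threads=%d\n",
         e.tb_cfg.thread_k, e.tb_cfg.thread_n, e.tb_cfg.num_threads);
  return 0;
}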
@@ -1574,10 +1594,12 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s,
     }
     CALL_IF(4, 32, 2, 256)
     CALL_IF(4, 16, 4, 256)
+    CALL_IF(4, 8, 8, 256)
     CALL_IF(4, 8, 4, 128)
     CALL_IF(4, 4, 8, 128)
     CALL_IF(8, 32, 2, 256)
     CALL_IF(8, 16, 4, 256)
+    CALL_IF(8, 8, 8, 256)
     CALL_IF(8, 8, 4, 128)
     CALL_IF(8, 4, 8, 128)
     else {
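The CALL_IF lines above sit in an if / else-if chain (note the bare else { that follows), so each invocation presumably expands into else if arms that compare the runtime configuration against one compile-time instantiation and launch it on a match. Below is a minimal C++ sketch of that dispatch pattern; the macro body, the run_kernel stand-in, and the parameter names are invented for illustration and are not the macro defined in this file.

#include <cstdio>

// Dummy stand-in for a templated kernel: the template parameters play the role
// of the compile-time tile configuration.
template <int NUM_BITS, int N_BLOCKS, int K_BLOCKS, int NUM_THREADS>
void run_kernel(int m, int n, int k) {
  printf("launch: bits=%d n_blocks=%d k_blocks=%d threads=%d for %dx%dx%d\n",
         NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS, m, n, k);
}

// Invented macro: one arm of an else-if dispatch chain that matches runtime
// values against one compile-time configuration.
#define CALL_IF_SKETCH(BITS, NB, KB, THREADS)                        \
  else if (num_bits == BITS && n_blocks == NB && k_blocks == KB &&   \
           num_threads == THREADS) {                                 \
    run_kernel<BITS, NB, KB, THREADS>(m, n, k);                      \
  }

int main() {
  int m = 16, n = 4096, k = 4096;
  int num_bits = 4, n_blocks = 8, k_blocks = 8, num_threads = 256;

  if (false) {
  }
  CALL_IF_SKETCH(4, 16, 4, 256)
  CALL_IF_SKETCH(4, 8, 8, 256)
  CALL_IF_SKETCH(8, 16, 4, 256)
  CALL_IF_SKETCH(8, 8, 8, 256)
  else {
    printf("no kernel instantiated for this configuration\n");
  }
  return 0;
}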