Marlin 24 prefill performance improvement (about 25% better on average) (#4983)

Alexander Matveev 2024-05-23 02:39:27 -04:00 committed by GitHub
parent ee3eea0a1b
commit 6066253296
4 changed files with 107 additions and 32 deletions
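
The gains come mainly from letting one kernel launch cover a much larger slice of the prefill batch: max_par grows from 16 to 64 (in the CUDA kernel, the Python config, and the workspace sizing), wider thread_m tile configurations are added, and the benchmark gains a Marlin 2:4 path plus new --limit-* flags. A rough coverage calculation in plain Python, using the 16-row tile and the 4-block parallel unit (max_m_blocks) visible in the kernel diff below, illustrates the per-launch limit; treat it as an illustration, not vLLM API:

TILE_ROWS = 16      # rows per Marlin tile
MAX_M_BLOCKS = 4    # tiles folded into one parallel sub-problem (see kernel diff)

def max_rows_per_launch(max_par: int) -> int:
    # Each parallel sub-problem covers MAX_M_BLOCKS * TILE_ROWS rows and up to
    # max_par of them are solved by a single kernel launch.
    return TILE_ROWS * MAX_M_BLOCKS * max_par

print(max_rows_per_launch(16))  # 1024 rows per launch before this commit
print(max_rows_per_launch(64))  # 4096 rows per launch after it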

View File

@@ -6,9 +6,13 @@ from benchmark_shapes import WEIGHT_SHAPES
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
     GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
+from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
+    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
+    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    MarlinWorkspace, marlin_quantize)
+    MarlinWorkspace, marlin_24_quantize, marlin_quantize)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     gptq_pack, quantize_weights, sort_weights)
@@ -44,6 +48,10 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
         marlin_rand_perm,
     ) = marlin_quantize(b, num_bits, group_size, act_order)
 
+    # Marlin_24 quant
+    (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
+     marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)
+
     # GPTQ quant
     (w_ref, q_w, s, g_idx,
      rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
@@ -56,28 +64,43 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
         (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
 
     # Prepare
-    marlin_workspace = MarlinWorkspace(size_n)
+    marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
+                                       GPTQ_MARLIN_MAX_PARALLEL)
+
+    marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
+                                          GPTQ_MARLIN_24_MAX_PARALLEL)
 
     globals = {
+        # Gen params
+        "num_bits": num_bits,
+        "group_size": group_size,
+        "size_m": size_m,
+        "size_n": size_n,
+        "size_k": size_k,
+        "a": a,
+        "a_tmp": a_tmp,
+        # Marlin params
         "marlin_w_ref": marlin_w_ref,
         "marlin_q_w": marlin_q_w,
         "marlin_s": marlin_s,
         "marlin_g_idx": marlin_g_idx,
         "marlin_sort_indices": marlin_sort_indices,
         "marlin_rand_perm": marlin_rand_perm,
+        "marlin_workspace": marlin_workspace,
+        "is_k_full": is_k_full,
+        # Marlin_24 params
+        "marlin_24_w_ref": marlin_24_w_ref,
+        "marlin_24_q_w_comp": marlin_24_q_w_comp,
+        "marlin_24_meta": marlin_24_meta,
+        "marlin_24_s": marlin_24_s,
+        "marlin_24_workspace": marlin_24_workspace,
+        # GPTQ params
         "q_w_gptq": q_w_gptq,
         "repack_sort_indices": repack_sort_indices,
-        "num_bits": num_bits,
-        "group_size": group_size,
-        "size_m": size_m,
-        "size_n": size_n,
-        "size_k": size_k,
-        "is_k_full": is_k_full,
-        "a": a,
-        "a_tmp": a_tmp,
+        # Kernels
         "gptq_marlin_gemm": ops.gptq_marlin_gemm,
+        "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
         "gptq_marlin_repack": ops.gptq_marlin_repack,
-        "marlin_workspace": marlin_workspace,
     }
 
     min_run_time = 1
@@ -105,6 +128,18 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
             description="gptq_marlin_gemm",
         ).blocked_autorange(min_run_time=min_run_time))
 
+    if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
+            and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
+        results.append(
+            benchmark.Timer(
+                stmt=
+                "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)",  # noqa: E501
+                globals=globals,
+                label=label,
+                sub_label=sub_label,
+                description="gptq_marlin_24_gemm",
+            ).blocked_autorange(min_run_time=min_run_time))
+
     results.append(
         benchmark.Timer(
             stmt=
@@ -135,8 +170,20 @@ def main(args):
                 continue
 
             for act_order in ACT_ORDER_OPTS:
+                if len(args.limit_act_order
+                       ) > 0 and act_order not in args.limit_act_order:
+                    continue
+
                 for is_k_full in K_FULL_OPTS:
+                    if len(args.limit_k_full
+                           ) > 0 and is_k_full not in args.limit_k_full:
+                        continue
+
                     for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
+                        if len(args.limit_num_bits
+                               ) > 0 and num_bits not in args.limit_num_bits:
+                            continue
+
                         for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
                             if len(
                                     args.limit_group_size
@@ -159,7 +206,7 @@ def main(args):
 # For quick benchmarking use:
-# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 # noqa E501
+# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
 #
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
@@ -178,6 +225,9 @@ if __name__ == "__main__":
     parser.add_argument("--limit-k", nargs="+", type=int, default=[])
     parser.add_argument("--limit-n", nargs="+", type=int, default=[])
     parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
+    parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[])
+    parser.add_argument("--limit-act-order", nargs="+", type=int, default=[])
+    parser.add_argument("--limit-k-full", nargs="+", type=int, default=[])
 
     args = parser.parse_args()
     main(args)
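
Outside the benchmark harness, the new timed path reduces to one marlin_24_quantize call plus ops.gptq_marlin_24_gemm, exactly as in the stmt string above. A self-contained sketch of that sequence, assuming a CUDA build of vLLM with these kernels available and using illustrative shapes:

import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MarlinWorkspace, marlin_24_quantize)

num_bits, group_size = 4, 128
size_m, size_k, size_n = 256, 4096, 4096      # prefill-like batch, illustrative

a = torch.randn((size_m, size_k), dtype=torch.half, device="cuda")
b = torch.rand((size_k, size_n), dtype=torch.half, device="cuda")

# 2:4-sparse Marlin quantization of the weight, as in the benchmark above.
(w_24_ref, q_w_24_comp, meta, s) = marlin_24_quantize(b, num_bits, group_size)

workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
                            GPTQ_MARLIN_24_MAX_PARALLEL)

out = ops.gptq_marlin_24_gemm(a, q_w_24_comp, meta, s, workspace.scratch,
                              num_bits, size_m, size_n, size_k)
torch.cuda.synchronize()
print(out.shape)  # (size_m, size_n)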

View File

@@ -48,12 +48,12 @@ namespace marlin_24 {
 // than 1 warp per schedule allows some more latency hiding. At the same time,
 // we want relatively few warps to have many registers per warp and small tiles.
 static constexpr int THREADS = 256;
-static constexpr int STAGES = 4;  // 4 pipeline stages fit into shared memory
+static constexpr int STAGES = 4;
 static constexpr int min_thread_n = 128;
 static constexpr int tile_size = 16;
-static constexpr int max_par = 16;
+static constexpr int max_par = 64;
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
@@ -736,10 +736,10 @@ __global__ void Marlin_24(
   for (int pipe = 0; pipe < stages;) {
     fetch_to_shared((pipe + stages - 1) % stages, pipe,
                     slice_iters >= stages);
-    matmul(pipe);
     wait_for_stage();
     fetch_to_registers(pipe + 1, (pipe + 1) % stages);
+    matmul(pipe);
     pipe++;
     slice_iters--;
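
The hunk above only moves matmul(pipe): the next stage's register fetch is now issued before the current stage's math, so the loads can overlap with compute instead of waiting behind it. A plain-Python sketch of the new issue order (no-op stand-ins, meant only to show ordering, not CUDA semantics):

def pipeline_step(pipe, stages, slice_iters,
                  fetch_to_shared, wait_for_stage, fetch_to_registers, matmul):
    # New order after this commit; matmul(pipe) used to run right after
    # fetch_to_shared(...), before wait_for_stage().
    fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages)
    wait_for_stage()
    fetch_to_registers(pipe + 1, (pipe + 1) % stages)
    matmul(pipe)

# Trace the order with print stand-ins:
pipeline_step(0, 4, 8,
              lambda *a: print("fetch_to_shared", a),
              lambda: print("wait_for_stage"),
              lambda *a: print("fetch_to_registers", a),
              lambda p: print("matmul", p))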
@@ -899,9 +899,12 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
       // than better compute utilization
       thread_k = 128;
       thread_m = 128;
-    } else {
+    } else if (prob_n <= 256) {
       thread_k = 64;
       thread_m = 256;
+    } else {
+      thread_k = 32;
+      thread_m = 512;
     }
   }
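
The auto-selection above gains a third tier: once prob_n (which in this kernel's internal naming tracks the batch, i.e. prefill length) exceeds 256, a narrower-K, taller-M 32x512 thread tile is chosen instead of 64x256. A small Python mirror of just the branches visible in this hunk (the earlier 128x128 branch sits outside it):

def pick_thread_tile(prob_n: int) -> tuple:
    # Mirrors only the tail of the heuristic shown above.
    if prob_n <= 256:
        return 64, 256   # thread_k, thread_m
    return 32, 512       # new tier added by this commit

print(pick_thread_tile(128))   # (64, 256)
print(pick_thread_tile(2048))  # (32, 512)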
@@ -928,19 +931,21 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
   int4* C_ptr = (int4*)C;
   const int4* s_ptr = (const int4*)s;
 
+  constexpr int max_m_blocks = 4;
+
   int* locks = (int*)workspace;
-  for (int i = 0; i < tot_n_blocks; i += 4) {
+  for (int i = 0; i < tot_n_blocks; i += max_m_blocks) {
     int thread_n_blocks = tot_n_blocks - i;
     prob_n = tot_n - 16 * i;
     int par = 1;
-    if (thread_n_blocks > 4) {
+    if (thread_n_blocks > max_m_blocks) {
       // Note that parallel > 1 currently only works for inputs without any
       // padding
-      par = (16 * thread_n_blocks - pad) / 64;
+      par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16);
       if (par > max_par) par = max_par;
-      prob_n = 64 * par;
-      i += 4 * (par - 1);
-      thread_n_blocks = 4;
+      prob_n = (max_m_blocks * 16) * par;
+      i += max_m_blocks * (par - 1);
+      thread_n_blocks = max_m_blocks;
     }
 
     // For compilation speed, we only define the kernel configurations that have
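
With max_m_blocks = 4 and 16-row tiles, each parallel sub-problem spans 64 rows and the loop above folds up to max_par of them into a single launch. A Python simulation of that partitioning for a padding-free input shows why lifting max_par from 16 to 64 cuts the launch count for long prefills (it mirrors the arithmetic above, not the real dispatch code):

TILE = 16
MAX_M_BLOCKS = 4

def launches(tot_rows: int, max_par: int) -> int:
    # Count loop iterations (kernel launches) for pad == 0.
    tot_blocks = tot_rows // TILE
    i, count = 0, 0
    while i < tot_blocks:
        blocks = tot_blocks - i
        if blocks > MAX_M_BLOCKS:
            par = min((TILE * blocks) // (MAX_M_BLOCKS * TILE), max_par)
            i += MAX_M_BLOCKS * (par - 1)
        i += MAX_M_BLOCKS
        count += 1
    return count

print(launches(4096, max_par=16))  # 4 launches with the old limit
print(launches(4096, max_par=64))  # 1 launch after this commit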
@@ -951,8 +956,9 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
     if (false) {
     }  // BMxBNxBK, group
     // 4-bit
     CALL_IF_2_4(4, 8, 1, 4, -1)   // e.g., 16x128x128
     CALL_IF_2_4(4, 8, 1, 4, 4)    // e.g., 16x128x128, 64
     CALL_IF_2_4(4, 16, 1, 2, -1)  // e.g., 16x256x64
     CALL_IF_2_4(4, 16, 1, 2, 4)   // e.g., 16x256x64, 64
     CALL_IF_2_4(4, 16, 2, 2, -1)  // e.g.. 32x256x64
@@ -962,9 +968,19 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
     CALL_IF_2_4(4, 16, 4, 2, -1)
     CALL_IF_2_4(4, 16, 4, 2, 4)
+    CALL_IF_2_4(4, 32, 1, 1, -1)  // e.g., 16x256x64
+    CALL_IF_2_4(4, 32, 1, 1, 4)   // e.g., 16x256x64, 64
+    CALL_IF_2_4(4, 32, 2, 1, -1)  // e.g.. 32x256x64
+    CALL_IF_2_4(4, 32, 2, 1, 4)
+    CALL_IF_2_4(4, 32, 3, 1, -1)
+    CALL_IF_2_4(4, 32, 3, 1, 4)
+    CALL_IF_2_4(4, 32, 4, 1, -1)
+    CALL_IF_2_4(4, 32, 4, 1, 4)
+
     // 8-bit
     CALL_IF_2_4(8, 8, 1, 4, -1)   // e.g., 16x128x128
     CALL_IF_2_4(8, 8, 1, 4, 4)    // e.g., 16x128x128, 64
     CALL_IF_2_4(8, 16, 1, 2, -1)  // e.g., 16x256x64
     CALL_IF_2_4(8, 16, 1, 2, 4)   // e.g., 16x256x64, 64
     CALL_IF_2_4(8, 16, 2, 2, -1)  // e.g.. 32x256x64
@@ -973,6 +989,15 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
     CALL_IF_2_4(8, 16, 3, 2, 4)
     CALL_IF_2_4(8, 16, 4, 2, -1)
     CALL_IF_2_4(8, 16, 4, 2, 4)
+    CALL_IF_2_4(8, 32, 1, 1, -1)  // e.g., 16x256x64
+    CALL_IF_2_4(8, 32, 1, 1, 4)   // e.g., 16x256x64, 64
+    CALL_IF_2_4(8, 32, 2, 1, -1)  // e.g.. 32x256x64
+    CALL_IF_2_4(8, 32, 2, 1, 4)
+    CALL_IF_2_4(8, 32, 3, 1, -1)
+    CALL_IF_2_4(8, 32, 3, 1, 4)
+    CALL_IF_2_4(8, 32, 4, 1, -1)
+    CALL_IF_2_4(8, 32, 4, 1, 4)
+
     else {
       throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
                                ", " + str(prob_k) + ", " + str(prob_n) + "]" +
@@ -1062,7 +1087,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   int thread_k = -1;
   int thread_m = -1;
   int sms = -1;
-  int max_par = 16;
+  int max_par = marlin_24::max_par;
 
   int groupsize = -1;
   if (b_scales.size(0) > 1) {

View File

@@ -27,7 +27,7 @@ MARLIN_K_CHUNKS = [128]
 MARLIN_N_CHUNKS = [64, 128, 256]
 
 MARLIN_24_K_CHUNKS = [128]
-MARLIN_24_N_CHUNKS = [256]
+MARLIN_24_N_CHUNKS = [512]
 
 MNK_FACTORS = [
     (1, 1, 1),

View File

@@ -15,7 +15,7 @@ logger = init_logger(__name__)
 GPTQ_MARLIN_24_TILE = 16
 GPTQ_MARLIN_24_MIN_THREAD_N = 128
 GPTQ_MARLIN_24_MIN_THREAD_K = 128
-GPTQ_MARLIN_24_MAX_PARALLEL = 16
+GPTQ_MARLIN_24_MAX_PARALLEL = 64
 
 GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
 GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
@@ -53,14 +53,14 @@ class GPTQMarlin24Config(QuantizationConfig):
         self.tile_size = 16
 
         # Min out_features dim
-        self.min_n_threads = 128
+        self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N
 
         # Min in_features dim
-        self.min_k_threads = 128
+        self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K
 
         # Max parallel problems to solve at once (improves large
        # batch performance)
-        self.max_parallel = 16
+        self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL
 
         # Permutation length used by the marlin kernels.
         self.perm_len = 1024
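
The scratch lock workspace is sized from these constants, so raising GPTQ_MARLIN_24_MAX_PARALLEL from 16 to 64 also grows the per-layer allocation, roughly fourfold under the sizing sketched below. A hedged sketch of that rule, following how the benchmark above combines the constants in MarlinWorkspace (the exact internals are an assumption here):

import torch

GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64

def workspace_scratch(out_features: int,
                      min_thread_n: int = GPTQ_MARLIN_24_MIN_THREAD_N,
                      max_parallel: int = GPTQ_MARLIN_24_MAX_PARALLEL,
                      device: str = "cpu") -> torch.Tensor:
    # Assumed sizing: one int lock per N tile group per parallel sub-problem.
    assert out_features % min_thread_n == 0
    return torch.zeros((out_features // min_thread_n) * max_parallel,
                       dtype=torch.int, device=device)

# A 4096-wide output layer: 4096 // 128 * 64 = 2048 lock ints (was 512 at 16).
print(workspace_scratch(4096).numel())  # 2048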