Marlin 24 prefill performance improvement (about 25% better on average) (#4983)
Parent: ee3eea0a1b
Commit: 6066253296
@@ -6,9 +6,13 @@ from benchmark_shapes import WEIGHT_SHAPES
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
     GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
+from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
+    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
+    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    MarlinWorkspace, marlin_quantize)
+    MarlinWorkspace, marlin_24_quantize, marlin_quantize)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     gptq_pack, quantize_weights, sort_weights)
 
@@ -44,6 +48,10 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
         marlin_rand_perm,
     ) = marlin_quantize(b, num_bits, group_size, act_order)
 
+    # Marlin_24 quant
+    (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
+     marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)
+
     # GPTQ quant
     (w_ref, q_w, s, g_idx,
      rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
@@ -56,28 +64,43 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
         (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
 
     # Prepare
-    marlin_workspace = MarlinWorkspace(size_n)
+    marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
+                                       GPTQ_MARLIN_MAX_PARALLEL)
+
+    marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
+                                          GPTQ_MARLIN_24_MAX_PARALLEL)
+
     globals = {
+        # Gen params
+        "num_bits": num_bits,
+        "group_size": group_size,
+        "size_m": size_m,
+        "size_n": size_n,
+        "size_k": size_k,
+        "a": a,
+        "a_tmp": a_tmp,
+        # Marlin params
         "marlin_w_ref": marlin_w_ref,
         "marlin_q_w": marlin_q_w,
         "marlin_s": marlin_s,
         "marlin_g_idx": marlin_g_idx,
         "marlin_sort_indices": marlin_sort_indices,
         "marlin_rand_perm": marlin_rand_perm,
+        "marlin_workspace": marlin_workspace,
+        "is_k_full": is_k_full,
+        # Marlin_24 params
+        "marlin_24_w_ref": marlin_24_w_ref,
+        "marlin_24_q_w_comp": marlin_24_q_w_comp,
+        "marlin_24_meta": marlin_24_meta,
+        "marlin_24_s": marlin_24_s,
+        "marlin_24_workspace": marlin_24_workspace,
+        # GPTQ params
         "q_w_gptq": q_w_gptq,
         "repack_sort_indices": repack_sort_indices,
-        "num_bits": num_bits,
-        "group_size": group_size,
-        "size_m": size_m,
-        "size_n": size_n,
-        "size_k": size_k,
-        "is_k_full": is_k_full,
-        "a": a,
-        "a_tmp": a_tmp,
+        # Kernels
         "gptq_marlin_gemm": ops.gptq_marlin_gemm,
+        "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
         "gptq_marlin_repack": ops.gptq_marlin_repack,
-        "marlin_workspace": marlin_workspace,
     }
 
     min_run_time = 1
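Note: the benchmark now sizes both workspaces from the kernels' own constants (GPTQ_MARLIN_MIN_THREAD_N / GPTQ_MARLIN_MAX_PARALLEL and their 2:4 counterparts) instead of a bare MarlinWorkspace(size_n). The diff does not show MarlinWorkspace itself, so the following is only a minimal sketch of how such a scratch buffer could be sized, under the assumption of one synchronization slot per output-column tile per parallel problem; make_workspace is a hypothetical stand-in, not the vLLM implementation.

    import torch

    def make_workspace(out_features: int, min_thread_n: int, max_parallel: int):
        # Assumed sizing: one int slot per column tile, replicated for each
        # problem instance that may run in parallel. The real MarlinWorkspace
        # may size or initialize its buffer differently.
        num_slots = (out_features // min_thread_n) * max_parallel
        return torch.zeros(num_slots, dtype=torch.int, device="cuda")

    # Hypothetical equivalent of the benchmark's marlin_24_workspace:
    # scratch = make_workspace(size_n, 128, 64)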
@@ -105,6 +128,18 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
                 description="gptq_marlin_gemm",
             ).blocked_autorange(min_run_time=min_run_time))
 
+    if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
+            and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
+        results.append(
+            benchmark.Timer(
+                stmt=
+                "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)",  # noqa: E501
+                globals=globals,
+                label=label,
+                sub_label=sub_label,
+                description="gptq_marlin_24_gemm",
+            ).blocked_autorange(min_run_time=min_run_time))
+
     results.append(
         benchmark.Timer(
             stmt=
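The stmt string above packs the whole gptq_marlin_24_gemm call onto one line for the Timer. Written out directly (same arguments, all prepared earlier in bench_run; this is just the timed statement reformatted, not a new API):

    from vllm import _custom_ops as ops

    # Direct form of the timed statement: activations, compressed 2:4 weights,
    # sparsity metadata, scales, and the workspace scratch buffer, followed by
    # the bit width and problem sizes.
    output = ops.gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta,
                                     marlin_24_s, marlin_24_workspace.scratch,
                                     num_bits, size_m, size_n, size_k)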
@@ -135,8 +170,20 @@ def main(args):
             continue
 
         for act_order in ACT_ORDER_OPTS:
+            if len(args.limit_act_order
+                   ) > 0 and act_order not in args.limit_act_order:
+                continue
+
             for is_k_full in K_FULL_OPTS:
+                if len(args.limit_k_full
+                       ) > 0 and is_k_full not in args.limit_k_full:
+                    continue
+
                 for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
+                    if len(args.limit_num_bits
+                           ) > 0 and num_bits not in args.limit_num_bits:
+                        continue
+
                     for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
                         if len(
                                 args.limit_group_size
@@ -159,7 +206,7 @@ def main(args):
 
 
 # For quick benchmarking use:
-#   python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128  # noqa E501
+#   python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1  # noqa E501
 #
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
@@ -178,6 +225,9 @@ if __name__ == "__main__":
     parser.add_argument("--limit-k", nargs="+", type=int, default=[])
     parser.add_argument("--limit-n", nargs="+", type=int, default=[])
     parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
+    parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[])
+    parser.add_argument("--limit-act-order", nargs="+", type=int, default=[])
+    parser.add_argument("--limit-k-full", nargs="+", type=int, default=[])
 
     args = parser.parse_args()
     main(args)
@@ -48,12 +48,12 @@ namespace marlin_24 {
 // than 1 warp per schedule allows some more latency hiding. At the same time,
 // we want relatively few warps to have many registers per warp and small tiles.
 static constexpr int THREADS = 256;
-static constexpr int STAGES = 4;  // 4 pipeline stages fit into shared memory
+static constexpr int STAGES = 4;
 
 static constexpr int min_thread_n = 128;
 
 static constexpr int tile_size = 16;
-static constexpr int max_par = 16;
+static constexpr int max_par = 64;
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
@@ -736,10 +736,10 @@ __global__ void Marlin_24(
     for (int pipe = 0; pipe < stages;) {
       fetch_to_shared((pipe + stages - 1) % stages, pipe,
                       slice_iters >= stages);
+      matmul(pipe);
       wait_for_stage();
 
       fetch_to_registers(pipe + 1, (pipe + 1) % stages);
-      matmul(pipe);
 
       pipe++;
       slice_iters--;
@@ -899,9 +899,12 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
       // than better compute utilization
       thread_k = 128;
       thread_m = 128;
-    } else {
+    } else if (prob_n <= 256) {
       thread_k = 64;
       thread_m = 256;
+    } else {
+      thread_k = 32;
+      thread_m = 512;
     }
   }
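The default tile selection now has a third tier for larger problems. Restated as a sketch (Python for brevity; pick_default_tile is hypothetical, and small_problem stands in for the branch condition that sits above this hunk and is not shown in the diff):

    def pick_default_tile(prob_n: int, small_problem: bool) -> tuple:
        # Mirrors only the branches visible in this hunk.
        if small_problem:
            return 128, 128   # thread_k, thread_m: favor partitioning/latency
        elif prob_n <= 256:
            return 64, 256
        else:
            return 32, 512    # new wider configuration added by this commit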
@@ -928,19 +931,21 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
   int4* C_ptr = (int4*)C;
   const int4* s_ptr = (const int4*)s;
 
+  constexpr int max_m_blocks = 4;
+
   int* locks = (int*)workspace;
-  for (int i = 0; i < tot_n_blocks; i += 4) {
+  for (int i = 0; i < tot_n_blocks; i += max_m_blocks) {
     int thread_n_blocks = tot_n_blocks - i;
     prob_n = tot_n - 16 * i;
     int par = 1;
-    if (thread_n_blocks > 4) {
+    if (thread_n_blocks > max_m_blocks) {
       // Note that parallel > 1 currently only works for inputs without any
       // padding
-      par = (16 * thread_n_blocks - pad) / 64;
+      par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16);
       if (par > max_par) par = max_par;
-      prob_n = 64 * par;
+      prob_n = (max_m_blocks * 16) * par;
-      i += 4 * (par - 1);
+      i += max_m_blocks * (par - 1);
-      thread_n_blocks = 4;
+      thread_n_blocks = max_m_blocks;
     }
 
     // For compilation speed, we only define the kernel configurations that have
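Taken together with max_par = 64 (set earlier in this diff) and max_m_blocks = 4 (64 rows per chunk at the 16-row tile size), one launch can now fold far more of a prefill batch into parallel sub-problems. Below is a small restatement of the host-side partitioning above, assuming pad == 0 and that the dimension this file calls "n" is the input-row (batch/prefill) dimension; that interpretation is inferred from context rather than stated in the diff, and plan_partitions is a hypothetical helper:

    def plan_partitions(tot_rows: int, max_m_blocks: int = 4, max_par: int = 64,
                        tile: int = 16):
        # Each chunk covers max_m_blocks * tile rows; `par` folds several such
        # chunks into a single kernel launch, capped at max_par.
        tot_blocks = tot_rows // tile
        launches = []
        i = 0
        while i < tot_blocks:
            blocks = tot_blocks - i
            par = 1
            if blocks > max_m_blocks:
                par = min(blocks // max_m_blocks, max_par)
                blocks = max_m_blocks
            launches.append((blocks * tile, par))  # (rows per sub-problem, par)
            i += max_m_blocks * par
        return launches

    # plan_partitions(4096) -> [(64, 64)]: a 4096-row prefill fits in one launch,
    # where the previous cap of max_par = 16 needed four.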
@@ -951,8 +956,9 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
   if (false) {
   }  //         BMxBNxBK,   group
   // 4-bit
   CALL_IF_2_4(4, 8, 1, 4, -1)  // e.g., 16x128x128
   CALL_IF_2_4(4, 8, 1, 4, 4)   // e.g., 16x128x128, 64
+
   CALL_IF_2_4(4, 16, 1, 2, -1)  // e.g., 16x256x64
   CALL_IF_2_4(4, 16, 1, 2, 4)   // e.g., 16x256x64, 64
   CALL_IF_2_4(4, 16, 2, 2, -1)  // e.g.. 32x256x64
@@ -962,9 +968,19 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
   CALL_IF_2_4(4, 16, 4, 2, -1)
   CALL_IF_2_4(4, 16, 4, 2, 4)
 
+  CALL_IF_2_4(4, 32, 1, 1, -1)  // e.g., 16x256x64
+  CALL_IF_2_4(4, 32, 1, 1, 4)   // e.g., 16x256x64, 64
+  CALL_IF_2_4(4, 32, 2, 1, -1)  // e.g.. 32x256x64
+  CALL_IF_2_4(4, 32, 2, 1, 4)
+  CALL_IF_2_4(4, 32, 3, 1, -1)
+  CALL_IF_2_4(4, 32, 3, 1, 4)
+  CALL_IF_2_4(4, 32, 4, 1, -1)
+  CALL_IF_2_4(4, 32, 4, 1, 4)
+
   // 8-bit
   CALL_IF_2_4(8, 8, 1, 4, -1)  // e.g., 16x128x128
   CALL_IF_2_4(8, 8, 1, 4, 4)   // e.g., 16x128x128, 64
 
   CALL_IF_2_4(8, 16, 1, 2, -1)  // e.g., 16x256x64
   CALL_IF_2_4(8, 16, 1, 2, 4)   // e.g., 16x256x64, 64
   CALL_IF_2_4(8, 16, 2, 2, -1)  // e.g.. 32x256x64
@@ -973,6 +989,15 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
   CALL_IF_2_4(8, 16, 3, 2, 4)
   CALL_IF_2_4(8, 16, 4, 2, -1)
   CALL_IF_2_4(8, 16, 4, 2, 4)
+
+  CALL_IF_2_4(8, 32, 1, 1, -1)  // e.g., 16x256x64
+  CALL_IF_2_4(8, 32, 1, 1, 4)   // e.g., 16x256x64, 64
+  CALL_IF_2_4(8, 32, 2, 1, -1)  // e.g.. 32x256x64
+  CALL_IF_2_4(8, 32, 2, 1, 4)
+  CALL_IF_2_4(8, 32, 3, 1, -1)
+  CALL_IF_2_4(8, 32, 3, 1, 4)
+  CALL_IF_2_4(8, 32, 4, 1, -1)
+  CALL_IF_2_4(8, 32, 4, 1, 4)
   else {
     throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
                              ", " + str(prob_k) + ", " + str(prob_n) + "]" +
@@ -1062,7 +1087,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   int thread_k = -1;
   int thread_m = -1;
   int sms = -1;
-  int max_par = 16;
+  int max_par = marlin_24::max_par;
 
   int groupsize = -1;
   if (b_scales.size(0) > 1) {
@@ -27,7 +27,7 @@ MARLIN_K_CHUNKS = [128]
 MARLIN_N_CHUNKS = [64, 128, 256]
 
 MARLIN_24_K_CHUNKS = [128]
-MARLIN_24_N_CHUNKS = [256]
+MARLIN_24_N_CHUNKS = [512]
 
 MNK_FACTORS = [
     (1, 1, 1),
@@ -15,7 +15,7 @@ logger = init_logger(__name__)
 GPTQ_MARLIN_24_TILE = 16
 GPTQ_MARLIN_24_MIN_THREAD_N = 128
 GPTQ_MARLIN_24_MIN_THREAD_K = 128
-GPTQ_MARLIN_24_MAX_PARALLEL = 16
+GPTQ_MARLIN_24_MAX_PARALLEL = 64
 
 GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
 GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
@@ -53,14 +53,14 @@ class GPTQMarlin24Config(QuantizationConfig):
         self.tile_size = 16
 
         # Min out_features dim
-        self.min_n_threads = 128
+        self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N
 
         # Min in_features dim
-        self.min_k_threads = 128
+        self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K
 
         # Max parallel problems to solve at once (improves large
         # batch performance)
-        self.max_parallel = 16
+        self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL
 
         # Permutation length used by the marlin kernels.
         self.perm_len = 1024