[Kernel] Add marlin_24 unit tests (#4901)
parent f68470e803
commit 27ce85476e
@@ -7,23 +7,32 @@ import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
+from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
+    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
+    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
+from vllm.model_executor.layers.quantization.utils.marlin_perms import (
+    marlin_perm)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    MarlinWorkspace, is_marlin_supported, marlin_quantize, marlin_weights)
+    MarlinWorkspace, compute_max_diff, is_marlin_supported, marlin_24_quantize,
+    marlin_quantize, marlin_weights)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    gptq_pack, quantize_weights, sort_weights)

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]

-K_CHUNKS = [128, 256]
-N_CHUNKS = [64, 128, 256]
+MARLIN_K_CHUNKS = [128]
+MARLIN_N_CHUNKS = [64, 128, 256]
+
+MARLIN_24_K_CHUNKS = [128]
+MARLIN_24_N_CHUNKS = [256]

MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (1, 7 * 4, 5 * 1),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
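Editor's note (not part of the diff): the chunk lists and MNK_FACTORS combine multiplicatively into the GEMM problem sizes the tests exercise. A minimal sketch of the expansion, using the same unpacking order (m_factor, n_factor, k_factor) as test_marlin_24_gemm below:

# Sketch: how one parametrized case expands into a problem shape.
# size_m = m_factor, size_k = k_chunk * k_factor, size_n = n_chunk * n_factor
m_factor, n_factor, k_factor = (13, 17, 67)
k_chunk, n_chunk = 128, 256
size_m, size_k, size_n = m_factor, k_chunk * k_factor, n_chunk * n_factor
print(size_m, size_k, size_n)  # 13 8576 4352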
@@ -31,14 +40,13 @@ MNK_FACTORS = [


def rand_data(shape):
-    data = torch.rand(shape).to(torch.half).cuda()
-    return data
+    return torch.randn(shape, dtype=torch.half, device="cuda")


@pytest.mark.skipif(not is_marlin_supported(),
                    reason="Marlin is not supported on this GPU type.")
-@pytest.mark.parametrize("k_chunk", K_CHUNKS)
-@pytest.mark.parametrize("n_chunk", N_CHUNKS)
+@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
+@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@@ -82,7 +90,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Pack to Marlin format
-    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits)
+    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits,
+                                  marlin_perm[num_bits])

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.gptq_marlin_repack(
@@ -99,8 +108,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,


@pytest.mark.skipif(not is_marlin_supported(),
                    reason="Marlin is not supported on this GPU type.")
-@pytest.mark.parametrize("k_chunk", K_CHUNKS)
-@pytest.mark.parametrize("n_chunk", N_CHUNKS)
+@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
+@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@@ -136,7 +145,8 @@ def test_marlin_gemm(
    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, num_bits, group_size, act_order)

-    workspace = MarlinWorkspace(size_n)
+    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
+                                GPTQ_MARLIN_MAX_PARALLEL)

    output = ops.gptq_marlin_gemm(
        a_input,
@@ -155,4 +165,55 @@ def test_marlin_gemm(

    torch.cuda.synchronize()

-    assert torch.allclose(output, output_ref, rtol=1e-2)
+    max_diff = compute_max_diff(output, output_ref)
+    print("max_diff = {}".format(max_diff))
+
+    assert max_diff < 0.04
+
+
+@pytest.mark.skipif(not is_marlin_supported(),
+                    reason="Marlin is not supported on this GPU type.")
+@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
+@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
+@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
+@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
+@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
+def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
+    m_factor, n_factor, k_factor = mnk_factors
+
+    size_m = m_factor
+    size_k = k_chunk * k_factor
+    size_n = n_chunk * n_factor
+
+    print(f"MNK = {size_m} {size_n} {size_k}")
+    print(f"groupsize = {group_size}")
+
+    a_input = rand_data((size_m, size_k))
+    b_weight = rand_data((size_k, size_n))
+
+    (w_24_ref, marlin_24_q_w_comp, marlin_24_meta,
+     marlin_24_s) = marlin_24_quantize(b_weight, num_bits, group_size)
+
+    workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
+                                   GPTQ_MARLIN_24_MAX_PARALLEL)
+
+    output_ref = torch.matmul(a_input, w_24_ref)
+
+    output = ops.gptq_marlin_24_gemm(
+        a_input,
+        marlin_24_q_w_comp,
+        marlin_24_meta,
+        marlin_24_s,
+        workspace_24.scratch,
+        num_bits,
+        a_input.shape[0],
+        b_weight.shape[1],
+        a_input.shape[1],
+    )
+
+    torch.cuda.synchronize()
+
+    max_diff = compute_max_diff(output, output_ref)
+    print("max_diff = {}".format(max_diff))
+
+    assert max_diff < 0.04
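Editor's note (not part of the diff): these are plain parametrized pytest tests, so the new 2:4 sparse case can be selected by keyword. The test file path below is an assumption; it is not shown in this diff.

#   pytest tests/kernels/test_marlin_gemm.py -k marlin_24 -v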
vllm/model_executor/layers/quantization/gptq_marlin_24.py
@@ -12,6 +12,15 @@ from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)

+GPTQ_MARLIN_24_TILE = 16
+GPTQ_MARLIN_24_MIN_THREAD_N = 128
+GPTQ_MARLIN_24_MIN_THREAD_K = 128
+GPTQ_MARLIN_24_MAX_PARALLEL = 16
+
+GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
+GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
+GPTQ_MARLIN_24_SUPPORTED_SYM = [True]
+

class GPTQMarlin24Config(QuantizationConfig):
    """Config class for Marlin24.
@@ -25,15 +34,17 @@ class GPTQMarlin24Config(QuantizationConfig):
        self.weight_bits = weight_bits
        self.group_size = group_size

-        if self.weight_bits != 4 and self.weight_bits != 8:
-            raise ValueError("weight_bits must be 4 or 8. Got = {}".format(
-                self.weight_bits))
-
-        if self.group_size != 128 and self.group_size != -1:
-            raise ValueError(
-                "Currently, only group size 128 and -1 (channelwise) "
-                "is supported for Marlin24, but got group_size of "
-                f"{self.group_size}")
+        # Verify
+        if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
+            raise ValueError(
+                f"Marlin_24 does not support weight_bits = {self.weight_bits}. "
+                f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} "
+                "are supported.")
+        if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
+            raise ValueError(
+                f"Marlin_24 does not support group_size = {self.group_size}. "
+                f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
+                "are supported.")

        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = 32 // self.weight_bits
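Editor's note (not part of the diff): a minimal sketch of how the new constant-driven validation behaves. It assumes GPTQMarlin24Config can be constructed directly with weight_bits and group_size, as the __init__ above suggests.

# Sketch only; assumes direct construction with (weight_bits, group_size).
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQMarlin24Config)

GPTQMarlin24Config(weight_bits=4, group_size=128)  # accepted: 4 bits, g=128
try:
    GPTQMarlin24Config(weight_bits=3, group_size=128)
except ValueError as err:
    print(err)  # "Marlin_24 does not support weight_bits = 3. ..."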
vllm/model_executor/layers/quantization/utils/format_24.py (new file, 308 lines)
@@ -0,0 +1,308 @@
#
# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
#

import torch


# This is PyTorch implementation of main part of reorder_meta()
# function, from tools/util/include/cutlass/util/host_reorder.h file
# of CUTLASS source tree. Furthermore, CUTLASS template for sparse
# GEMM decides upon layout of this matrix, and at the moment for the
# sparse GEMM executed on tensor cores, this is layout described by
# ColumnMajorInterleaved<2> data structure, in
# include/cutlass/layout/matrix.h of CUTLASS source tree. The
# reordering of meta matrix into meta_reordered matrix calculated
# according to these segments of CUTLASS code is re-implemented here.
# Note that this calculation produces offsets for scattering metadata
# matrix elements into reordered metadata matrix elements (or,
# equivalently, for gathering reordered metadata matrix element back
# into metadata matrix elements).
def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype,
                                               device):
    dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
    dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)

    # Reorder the rows, then swizzle the 2x2 blocks.
    group_x = 64
    group_y = 32 if meta_dtype.itemsize == 2 else 16

    dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 +
                (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 +
                ((dst_rows % group_x) // 8) * 4)

    topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
    bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
    dst_rows += topright - bottomleft
    dst_cols -= topright - bottomleft

    # Assumed that meta tensor is to be stored in CUTLASS
    # InterleavedColumnMajor layout, and reverse engineered
    # corresponding code to store values into this tensor.
    interleave = 2
    cols_maj = dst_cols // interleave
    cols_min = dst_cols % interleave
    return (cols_maj * m * interleave + dst_rows * interleave +
            cols_min).view(-1)


# This function converts dense matrix into sparse semi-structured
# representation, producing "compressed" matrix, in the layout used by
# CUTLASS backend, and corresponding metadata matrix.
def sparse_semi_structured_from_dense_cutlass(dense):
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = dense.shape
    device = dense.device

    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError(
            "Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 16")
    else:
        if m % 32 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 32")
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"  # noqa: E501
        )

    if dense.dtype != torch.float:
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Encoding quadruples of True/False values as follows:
    #     [True, True, False, False] -> 0b0100
    #     [True, False, True, False] -> 0b1000
    #     [False, True, True, False] -> 0b1001
    #     [True, False, False, True ] -> 0b1100
    #     [False, True, False, True ] -> 0b1101
    #     [False, False, True, True ] -> 0b1110
    # Thus, lower two bits in the encoding are index of the True value
    # at the lowest index in the quadruple, and the higher two bits in
    # the encoding are index of the other True value in the quadruple.
    # In case there are less than two True values, than False value or
    # values at some index or indices are considered True for the
    # encoding. In case there are more than two True values, then the
    # excess True value(s) at some indices are considered False for
    # the encoding. The exact encodings used for these cases are as
    # follows:
    #     [False, False, False, False] -> 0b1110
    #     [False, False, False, True ] -> 0b1110
    #     [False, False, True, False] -> 0b1110
    #     [False, True, False, False] -> 0b1001
    #     [False, True, True, True ] -> 0b1101
    #     [True, False, False, False] -> 0b1000
    #     [True, False, True, True ] -> 0b1100
    #     [True, True, False, True ] -> 0b0100
    #     [True, True, True, False] -> 0b0100
    #     [True, True, True, True ] -> 0b0100
    # These particular encodings are chosen, with the help of Espresso
    # logic minimizer software, for the purpose of minimization of
    # corresponding Boolean functions, that translate non-zero flags
    # into encoding bits. Note also possible choices for the first
    # and last of these encodings were limited only to (0b0100,
    # 0b1110), in order to produce valid encodings for 1:2 sparsity
    # case.

    expr0 = m0 & m1
    expr1 = ~m0 & m1
    expr2 = ~m0 & ~m1
    bit0 = expr1
    bit1 = expr2
    bit2 = expr0 | expr2 | m3
    bit3 = expr1 | ~m1
    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    if dense.dtype != torch.float:
        sparse0 = dense_4.gather(
            -1, idxs0.unsqueeze(-1))  # type: ignore[possibly-undefined]
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        sparse = dense_2.gather(-1,
                                idxs0.unsqueeze(-1) // 2).view(
                                    m,
                                    k // 2)  # type: ignore[possibly-undefined]

    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view(
        (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    if quadbits_per_meta_elem == 4:
        meta = (meta_n[:, :, 0]
                | (meta_n[:, :, 1] << 4)
                | (meta_n[:, :, 2] << 8)
                | (meta_n[:, :, 3] << 12))
    elif quadbits_per_meta_elem == 8:
        meta = (meta_n[:, :, 0]
                | (meta_n[:, :, 1] << 4)
                | (meta_n[:, :, 2] << 8)
                | (meta_n[:, :, 3] << 12)
                | (meta_n[:, :, 4] << 16)
                | (meta_n[:, :, 5] << 20)
                | (meta_n[:, :, 6] << 24)
                | (meta_n[:, :, 7] << 28))

    # Reorder meta tensor elements.
    meta_reordered = meta.new_empty(
        (m * meta_ncols, ))  # type: ignore[possibly-undefined]
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device)
    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))

    return (sparse, meta_reordered.view(m, meta_ncols))


# This function performs reverse of the function above - it
# reconstructs dense matrix from a pair of "compressed" matrix, given
# in the layout used by CUTLASS backend, and accompanying metadata
# matrix.
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
    if sparse.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = sparse.shape
    device = sparse.device

    if meta_reordered.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"  # noqa: E501
        )
    if meta_reordered.device != device:
        raise RuntimeError(
            f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"  # noqa: E501
        )

    meta_dtype = meta_reordered.dtype
    if meta_dtype not in (torch.int16, torch.int32):
        raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4

    ksparse = 4 if sparse.dtype != torch.float else 2

    meta_nrows, meta_ncols = meta_reordered.shape
    if meta_nrows != m:
        raise RuntimeError(
            f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}"  # noqa: E501
        )
    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
        raise RuntimeError(
            f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "  # noqa: E501
            "expected according to the number of columns of meta matrix")

    # Undo meta tensor elements reordering.
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device)
    meta = torch.gather(meta_reordered.view(-1), 0,
                        meta_offsets).view(m, meta_ncols)

    # Unpack sparse tensor back to original dense tensor, using
    # information provided by meta tensor. Note that torch.float
    # datatype is handled pretty much the same as
    # torch.half/torch.bfloat16, as metadata for a pair of torch.float
    # value is encoded as if underlying 8 bytes contain four
    # torch.half/torch.bfloat16 values, where either first two or last
    # two are zeros.
    meta_2 = torch.empty(
        (m, meta_ncols, 2 * quadbits_per_meta_elem),
        dtype=meta_dtype,
        device=device,
    )
    if quadbits_per_meta_elem == 4:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
    elif quadbits_per_meta_elem == 8:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
        meta_2[:, :, 8] = (meta >> 16) & 0b11
        meta_2[:, :, 9] = (meta >> 18) & 0b11
        meta_2[:, :, 10] = (meta >> 20) & 0b11
        meta_2[:, :, 11] = (meta >> 22) & 0b11
        meta_2[:, :, 12] = (meta >> 24) & 0b11
        meta_2[:, :, 13] = (meta >> 26) & 0b11
        meta_2[:, :, 14] = (meta >> 28) & 0b11
        meta_2[:, :, 15] = (meta >> 30) & 0b11

    dense_offsets = meta_2.view(-1) + (
        torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view(
            -1, 1).repeat(1, 2).view(-1)

    dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device)
    if sparse.dtype != torch.float:
        # dense.scatter_(0, dense_offsets, sparse.view(-1))
        dense.scatter_(0, dense_offsets, sparse.reshape(-1))
    else:
        dense.view(torch.half).scatter_(0, dense_offsets,
                                        sparse.view(torch.half).view(-1))

    return dense.view(m, 2 * k)


def mask_creator(tensor):
    """
    Class for creating N:M sparsity masks.
    Masks will be created using the N:M ratio, where for every block of
    M weights, N will be pruned based on ranked weight value. Each mask
    will correspond to the given tensor.

    :param N: The number of weights in a group to keep
    :param M: The size of a weight group
    """
    N = 2
    M = 4

    mask = None
    # for i, tensor in enumerate(tensors):
    if tensor.numel() % M != 0:
        raise ValueError(
            f"Tensor of size {tensor.shape} can't be evenly divided into "
            f"{M} groups")

    num_groups = tensor.numel() // M

    # N:M sparsity for linear layers
    tensor_temp = tensor.detach().abs().reshape(num_groups, M)
    index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)]

    w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
    mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)

    return mask
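Editor's note (not part of the commit): a small, hedged round-trip sketch for the helpers above — make a random fp16 matrix 2:4-sparse with mask_creator, compress it, then reconstruct it. Shapes are chosen to satisfy the divisibility checks (rows % 32 == 0 and columns % 16 == 0 for fp16).

# Sketch only; all ops here run on CPU.
import torch

from vllm.model_executor.layers.quantization.utils.format_24 import (
    mask_creator, sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass)

dense = torch.randn((32, 64), dtype=torch.half)
dense = (dense * mask_creator(dense)).half()  # enforce 2:4 per group of 4

sparse, meta = sparse_semi_structured_from_dense_cutlass(dense)
print(sparse.shape, meta.shape)  # torch.Size([32, 32]) torch.Size([32, 4])

recovered = sparse_semi_structured_to_dense_cutlass(sparse, meta)
assert torch.equal(recovered, dense)  # exact for a genuinely 2:4-sparse input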
vllm/model_executor/layers/quantization/utils/marlin_24_perms.py (new file, 58 lines)
@@ -0,0 +1,58 @@
"""This file is used for /tests and /benchmarks"""
import numpy
import torch


# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501
#
# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def get_perms_24(num_bits):
    perm_list = []
    for i in range(32):
        perm1 = []
        col = i // 4
        col_o = col // 2
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col_o * 256 + 8 * (col % 2) +
                             4 * block)
        for j in range(4):
            perm_list.extend([p + 1 * j for p in perm1])
    perm = numpy.array(perm_list)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    scale_perm = []
    for i in range(8):
        scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
    scale_perm_single = []
    for i in range(8):
        scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
    return perm, scale_perm, scale_perm_single


marlin_24_perm = {}
marlin_24_scale_perm = {}
marlin_24_scale_perm_single = {}
for num_bits in [4, 8]:
    perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits)
    marlin_24_perm[num_bits] = perm_24
    marlin_24_scale_perm[num_bits] = scale_perm_24
    marlin_24_scale_perm_single[num_bits] = scale_perm_single_24
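Editor's note (not part of the commit): a hedged sanity check of the precomputed tables' sizes — each per-bit-width weight permutation covers 32 threads times 32 packed values, and both scale permutations cover 64 scale entries.

# Sketch only.
from vllm.model_executor.layers.quantization.utils.marlin_24_perms import (
    marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single)

assert marlin_24_perm[4].shape == marlin_24_perm[8].shape  # torch.Size([1024])
assert len(marlin_24_scale_perm[4]) == 64
assert len(marlin_24_scale_perm_single[4]) == 64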
vllm/model_executor/layers/quantization/utils/marlin_perms.py (new file, 58 lines)
@@ -0,0 +1,58 @@
"""This file is used for /tests and /benchmarks"""
import numpy
import torch


# Precompute permutations for Marlin weight and scale shuffling # noqa: E501
#
# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def get_perms(num_bits):
    perm_list = []
    for i in range(32):
        perm1 = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = numpy.array(perm_list)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    scale_perm = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return perm, scale_perm, scale_perm_single


marlin_perm = {}
marlin_scale_perm = {}
marlin_scale_perm_single = {}
for num_bits in [4, 8]:
    perm, scale_perm, scale_perm_single = get_perms(num_bits)
    marlin_perm[num_bits] = perm
    marlin_scale_perm[num_bits] = scale_perm
    marlin_scale_perm_single[num_bits] = scale_perm_single
vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -1,79 +1,28 @@
"""This file is used for /tests and /benchmarks"""
+import random
+
import numpy
import torch

-from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_TILE)
+from vllm.model_executor.layers.quantization.utils.format_24 import (
+    mask_creator, sparse_semi_structured_from_dense_cutlass)
+from vllm.model_executor.layers.quantization.utils.marlin_24_perms import (
+    marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single)
+from vllm.model_executor.layers.quantization.utils.marlin_perms import (
+    marlin_perm, marlin_scale_perm, marlin_scale_perm_single)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    get_pack_factor, quantize_weights, sort_weights)

__cuda_arch = torch.cuda.get_device_capability()

+MARLIN_TILE = 16
+

def is_marlin_supported():
    return __cuda_arch[0] >= 8


-# Precompute permutations for Marlin weight and scale shuffling # noqa: E501
-#
-# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
-# with the tensor-core format that is described here:
-# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
-#
-# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
-# (without the need to use ldmatrix instructions) # noqa: E501
-def _get_perms(num_bits):
-    perm_list = []
-    for i in range(32):
-        perm1 = []
-        col = i // 4
-        for block in [0, 1]:
-            for row in [
-                    2 * (i % 4),
-                    2 * (i % 4) + 1,
-                    2 * (i % 4 + 4),
-                    2 * (i % 4 + 4) + 1,
-            ]:
-                perm1.append(16 * row + col + 8 * block)
-        for j in range(4):
-            perm_list.extend([p + 256 * j for p in perm1])
-
-    perm = numpy.array(perm_list)
-
-    if num_bits == 4:
-        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
-    elif num_bits == 8:
-        interleave = numpy.array([0, 2, 1, 3])
-    else:
-        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
-
-    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
-    perm = torch.from_numpy(perm)
-    scale_perm = []
-    for i in range(8):
-        scale_perm.extend([i + 8 * j for j in range(8)])
-    scale_perm_single = []
-    for i in range(4):
-        scale_perm_single.extend(
-            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
-    return perm, scale_perm, scale_perm_single
-
-
-_perm = {}
-_scale_perm = {}
-_scale_perm_single = {}
-for num_bits in [4, 8]:
-    perm, scale_perm, scale_perm_single = _get_perms(num_bits)
-    _perm[num_bits] = perm
-    _scale_perm[num_bits] = scale_perm
-    _scale_perm_single[num_bits] = scale_perm_single
-
-
-def marlin_permute_weights(q_w,
-                           size_k,
-                           size_n,
-                           num_bits,
-                           tile=GPTQ_MARLIN_TILE):
+def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
@@ -83,15 +32,14 @@ def marlin_permute_weights(q_w,
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

-    q_w = q_w.reshape(
-        (-1, _perm[num_bits].numel()))[:, _perm[num_bits]].reshape(q_w.shape)
+    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w


-def marlin_weights(q_w, size_k, size_n, num_bits):
+def marlin_weights(q_w, size_k, size_n, num_bits, perm):
    # Permute
-    q_w = marlin_permute_weights(q_w, size_k, size_n, num_bits)
+    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
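Editor's note (not part of the commit): a hedged toy example of the generic marlin_permute_weights above — one row of 16x16 tiles per output row, with perm interleaving values inside each row. It assumes a CUDA-capable environment, since importing marlin_utils queries the device capability at module load.

# Sketch only.
import torch

from vllm.model_executor.layers.quantization.utils.marlin_perms import marlin_perm
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_permute_weights)

q_w = torch.arange(32 * 64).reshape(32, 64)  # stand-in quantized weight
out = marlin_permute_weights(q_w, size_k=32, size_n=64, perm=marlin_perm[4])
print(out.shape)  # torch.Size([2, 1024]) == (size_k // 16, size_n * 16)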
@@ -101,7 +49,6 @@ def marlin_weights(q_w, size_k, size_n, num_bits):

    q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
                           dtype=numpy.uint32)

    for i in range(pack_factor):
        q_packed |= q_w[:, i::pack_factor] << num_bits * i

@@ -110,15 +57,12 @@ def marlin_weights(q_w, size_k, size_n, num_bits):
    return q_packed


-def marlin_permute_scales(s, size_k, size_n, group_size, num_bits):
+def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm,
+                          scale_perm_single):
    if group_size < size_k and group_size != -1:
-        s = s.reshape((-1, len(_scale_perm[num_bits])))[:,
-                                                        _scale_perm[num_bits]]
+        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
-        s = s.reshape(
-            (-1,
-             len(_scale_perm_single[num_bits])))[:,
-                                                 _scale_perm_single[num_bits]]
+        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()

    return s
@@ -148,8 +92,11 @@ def marlin_quantize(
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Reformat to marlin
-    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits)
-    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, num_bits)
+    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits,
+                                marlin_perm[num_bits])
+    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size,
+                                     marlin_scale_perm[num_bits],
+                                     marlin_scale_perm_single[num_bits])

    # Create result
    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
@@ -159,15 +106,118 @@ def marlin_quantize(
    return res_list


+def inject_24(w, size_k, size_n):
+    assert w.shape == (size_k, size_n)
+
+    mask = mask_creator(w.t()).t().cuda().bool()
+
+    return (mask * w).contiguous(), mask.contiguous()
+
+
+def check_24(w, num_rows_to_sample=50, _verbose=False):
+    BLOCK_SIZE = 4
+    MAX_NON_ZEROS = 2
+
+    w = w.t().contiguous()
+
+    print("check_24: w.shape = {}".format(w.shape))
+
+    num_rows, num_cols = w.shape
+    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
+    if _verbose:
+        print(f"Sampled row idxs = {sampled_row_idxs}")
+
+    total_segments = 0
+    non_24_segments = 0
+    for i in sampled_row_idxs:
+        for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):
+            total_segments += 1
+            block = w[i, j:j + BLOCK_SIZE]
+            num_nonzero = torch.count_nonzero(block)
+            if num_nonzero > MAX_NON_ZEROS:
+                print("i = {} j = {} block = {}".format(i, j, block))
+                non_24_segments += 1
+
+    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
+
+
+def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
+    assert q_24.shape == (size_k, size_n)
+
+    # Remove zp to normalize over 0
+    max_q_val = (1 << num_bits) - 1
+    zp = (max_q_val + 1) // 2
+    q_24_no_zp = q_24 - zp
+
+    # Compress
+    q_24_no_zp = q_24_no_zp.t().contiguous()
+    q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(
+        q_24_no_zp)
+    q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()
+
+    # Restore zp
+    q_24_comp = q_24_no_zp_comp + zp
+
+    # Resize meta to its actual shape (without moving any data)
+    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
+
+    return q_24_comp, meta
+
+
+def marlin_24_quantize(
+    w: torch.Tensor,
+    num_bits: int,
+    group_size: int,
+):
+    size_k, size_n = w.shape
+
+    # Normalize group_size
+    if group_size == -1:
+        group_size = size_k
+    assert group_size <= size_k
+
+    # Inject 2:4 sparsity
+    w_24, mask_24 = inject_24(w, size_k, size_n)
+
+    # Quantize
+    w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24,
+                                                             num_bits,
+                                                             group_size,
+                                                             act_order=False)
+
+    # Compress quantized weight
+    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
+                                                     num_bits)
+    size_k_comp = size_k // 2
+
+    # Reformat to marlin
+    marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
+                                        num_bits, marlin_24_perm[num_bits])
+    marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size,
+                                        marlin_24_scale_perm[num_bits],
+                                        marlin_24_scale_perm_single[num_bits])
+
+    # Create result
+    res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
+    for i in range(len(res_list)):
+        res_list[i] = res_list[i].to(w.device)
+
+    return res_list
+
+
+def compute_max_diff(output, output_ref):
+    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
+        torch.abs(output_ref))
+
+
class MarlinWorkspace:

-    def __init__(self, out_features):
-        assert (out_features % GPTQ_MARLIN_MIN_THREAD_N == 0), (
-            "out_features = {} is undivisible by GPTQ_MARLIN_MIN_THREAD_N = {}"
-            .format(out_features, GPTQ_MARLIN_MIN_THREAD_N))
+    def __init__(self, out_features, min_thread_n, max_parallel):
+        assert (out_features % min_thread_n == 0), (
+            "out_features = {} is undivisible by min_thread_n = {}".format(
+                out_features, min_thread_n))

-        max_workspace_size = ((out_features // GPTQ_MARLIN_MIN_THREAD_N) *
-                              GPTQ_MARLIN_MAX_PARALLEL)
+        max_workspace_size = ((out_features // min_thread_n) * max_parallel)

        self.scratch = torch.zeros(max_workspace_size,
                                   dtype=torch.int,
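Editor's note (not part of the commit): a hedged sketch of the 2:4 helpers defined above. It assumes a CUDA device, because inject_24 moves its mask to the GPU.

# Sketch only.
import torch

from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    check_24, inject_24)

w = torch.randn((128, 256), dtype=torch.half, device="cuda")
w_24, mask = inject_24(w, size_k=128, size_n=256)
check_24(w_24)  # expected to report "0 / N do not have 2:4 structure."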