[Kernel] (2/N) Machete - Integrate into CompressedTensorsWNA16 and GPTQMarlin (#7701)

Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Divakar Verma <137818590+divakar-amd@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Lucas Wilkinson 2024-09-23 13:46:26 -04:00 committed by GitHub
parent ee5f34b1c2
commit 86e9c8df29
27 changed files with 1005 additions and 246 deletions


@ -223,6 +223,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")


@ -4,8 +4,10 @@ import itertools
import math
import pickle as pkl
import time
from typing import Callable, Iterable, List, Tuple
from itertools import product
from typing import Callable, Iterable, List, Optional, Tuple
import pandas as pd
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
@ -84,6 +86,10 @@ def loop_over_weights(
fn(a, w_ref, w_q, w_s)
_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
def bench(atype: torch.dtype,
wtype: ScalarType,
group_size: int,
@ -94,6 +100,8 @@ def bench(atype: torch.dtype,
sub_label: str,
benchmark_marlinv1: bool = True,
sweep_schedules: bool = True) -> Iterable[TMeasurement]:
global _SWEEP_SCHEDULES_RESULTS
a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
sub_label += f", L={len(weights)}"
@ -163,6 +171,11 @@ def bench(atype: torch.dtype,
best_schedule = None
schedules = ops.machete_supported_schedules(wtype)
for schedule in reversed(schedules):
schedule_M = int(schedule.split("_")[0].split("x")[1])
# Prune known bad schedules
if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
continue
def run(a, _, w_q, w_s, schedule=schedule):
ops.machete_gemm(a,
@ -175,6 +188,20 @@ def bench(atype: torch.dtype,
res = bench_fn(label, sub_label, "machete_best",
lambda: loop_over_weights(a, weights_machete, run))
results_row = {
"M": m,
"K": k,
"N": n,
"group_size": group_size,
"schedule": schedule,
"median": res.median,
}
if _SWEEP_SCHEDULES_RESULTS is None:
_SWEEP_SCHEDULES_RESULTS = pd.DataFrame(
columns=results_row.keys())
_SWEEP_SCHEDULES_RESULTS.\
loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row
print(f" {res.median:5.5} ", schedule)
if not best or res.median < best.median:
best = res
@ -235,18 +262,22 @@ def run_square_bench(args):
dim_sizes = list(
range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, args.sweep_schedules, MKNs)
make_output(data, MKNs, f"square_bench-{args.dtype}")
def run_range_bench(args):
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
n = len(dim_sizes)
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
MKNs = list(zip(Ms, Ks, Ns))
m_start, k_start, n_start = [int(x) for x in args.dim_start.split(",")]
m_end, k_end, n_end = [int(x) for x in args.dim_end.split(",")]
m_increment, k_increment, n_increment = \
[int(x) for x in args.dim_increment.split(",")]
Ms = list(range(m_start, m_end + 1, m_increment))
Ks = list(range(k_start, k_end + 1, k_increment))
Ns = list(range(n_start, n_end + 1, n_increment))
MKNs = list(product(Ms, Ks, Ns))
data = run(args.dtype, args.sweep_schedules, MKNs)
make_output(data, MKNs, f"range_bench-{args.dtype}")
@ -333,6 +364,9 @@ Benchmark Machete GEMM.
action="store_true",
help="Run a sweep over all supported schedules",
)
parser.add_argument("--sweep-csv-out",
help="CSV to store sweep results",
default="sch_sweep_results.csv")
subparsers = parser.add_subparsers(dest="cmd", required=True)
square_parser = subparsers.add_parser("square_bench")
@ -342,12 +376,21 @@ Benchmark Machete GEMM.
square_parser.set_defaults(func=run_square_bench)
range_parser = subparsers.add_parser("range_bench")
range_parser.add_argument("--dim-start", type=int, required=True)
range_parser.add_argument("--dim-end", type=int, required=True)
range_parser.add_argument("--dim-increment", type=int, required=True)
range_parser.add_argument("--m-constant", type=int, default=None)
range_parser.add_argument("--n-constant", type=int, default=None)
range_parser.add_argument("--k-constant", type=int, default=None)
range_parser.add_argument(
"--dim-start",
type=str,
required=True,
help="Start value for M,K,N as comma-separated list")
range_parser.add_argument(
"--dim-end",
type=str,
required=True,
help="End value (inclusive) for M,K,N as comma-separated list")
range_parser.add_argument(
"--dim-increment",
type=str,
required=True,
help="Increment value for M,K,N as comma-separated list")
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
@ -369,4 +412,9 @@ Benchmark Machete GEMM.
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()
_SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out
args.func(args)
if _SWEEP_SCHEDULES_RESULTS is not None:
_SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV)
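The sweep CSV written above can then be post-processed with pandas. A minimal sketch (column names come from results_row above, the default file name from --sweep-csv-out; pandas' to_csv writes an index column by default):

import pandas as pd

df = pd.read_csv("sch_sweep_results.csv", index_col=0)
# lowest-median schedule per problem shape
best = df.sort_values("median").groupby(["M", "K", "N"], as_index=False).first()
print(best[["M", "K", "N", "schedule", "median"]])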


@ -0,0 +1 @@
pandas


@ -67,9 +67,15 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{};
} else {
if (tensor.size(idx) == 1) {
// use 0 stride for dim with size 1, this is easier for
// cute/cutlass to optimize (helps the TMA code flatten dims)
return StrideEle{0};
} else {
return tensor.stride(idx);
}
}
} else {
// Extra strides are assumed to be 0 or 1
if constexpr (cute::is_static_v<StrideEle>) {


@ -113,6 +113,8 @@ torch::Tensor prepack_B(torch::Tensor const& B,
}; // namespace machete
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,

csrc/permute_cols.cu (new file, 88 lines)

@ -0,0 +1,88 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
static constexpr int default_threads = 256;
static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
// For a given "a" of size [M, K], performs a permutation of the K columns based
// on the given "perm" indices.
// Currently only supports 16-bit types (since we permute half types).
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) {
int start_row = block_rows * blockIdx.x;
int finish_row = start_row + block_rows;
if (finish_row > size_m) {
finish_row = size_m;
}
int cur_block_rows = std::max(finish_row - start_row, 0);
int row_stride = size_k * sizeof(half) / 16;
auto permute_row = [&](int row) {
int iters = size_k / default_threads;
int rest = size_k % default_threads;
int offset = row * row_stride;
half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
int base_k = 0;
for (int i = 0; i < iters; i++) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
base_k += default_threads;
}
if (rest) {
if (threadIdx.x < rest) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
}
}
};
for (int i = 0; i < cur_block_rows; i++) {
int cur_row = start_row + i;
if (cur_row < size_m) {
permute_row(cur_row);
}
}
}
// More efficient version of A[..., perm]
// taken from gptq_marlin.cu
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
auto dev = A.get_device();
auto stream = at::cuda::getCurrentCUDAStream(dev);
TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16,
"Currently only 16bit types are supported");
TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
TORCH_CHECK(A.size(-1) % 8 == 0,
"A columns must be a multiple of 8 (128bits)");
auto A_2d = A.view({-1, A.size(-1)});
torch::Tensor D = torch::empty_like(A);
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
int block_rows = div_ceil(A_2d.size(0), sms);
permute_cols_kernel<<<sms, default_threads, 0, stream>>>(
reinterpret_cast<int4 const*>(A_2d.const_data_ptr()),
perm.const_data_ptr<int>(), reinterpret_cast<int4*>(D.mutable_data_ptr()),
A_2d.size(0), A_2d.size(1), block_rows);
return D;
}
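From Python the new op behaves like fancy indexing on the last dimension; a usage sketch (requires a CUDA device; mirrors the new test added further below):

import torch
from vllm import _custom_ops as ops

a = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
perm = torch.randperm(4096, device="cuda").to(torch.int)
assert torch.equal(ops.permute_cols(a, perm), a[:, perm])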


@ -157,7 +157,7 @@ TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
@dataclass
@dataclass(frozen=True)
class ScheduleConfig:
tile_shape_mn: Tuple[int, int]
cluster_shape_mnk: Tuple[int, int, int]
@ -328,56 +328,137 @@ def generate():
# about how this works
SCRIPT_DIR = os.path.dirname(__file__)
schedules = [
ScheduleConfig(
tile_shape_mn=tile_shape_mn,
cluster_shape_mnk=cluster_shape_mnk,
kernel_schedule=kernel_schedule,
epilogue_schedule=epilogue_schedule,
tile_scheduler=tile_scheduler,
) for tile_shape_mn, cluster_shape_mnk in (
((128, 16), (1, 1, 1)),
((128, 32), (1, 1, 1)),
((128, 64), (1, 1, 1)),
((128, 128), (1, 1, 1)),
) for kernel_schedule in (TmaMI, ) for epilogue_schedule in (TmaCoop, )
for tile_scheduler in (TileSchedulerType.StreamK, )
]
schedule_common_params = dict(
kernel_schedule=TmaMI,
epilogue_schedule=TmaCoop,
tile_scheduler=TileSchedulerType.StreamK,
)
# For now we use the same heuristic for all types
# Heuristic is currently tuned for H100s
default_heuristic = [
("M > 64",
#### M = 257+
(
"M > 256 && K <= 16384 && N <= 4096",
ScheduleConfig(
tile_shape_mn=(128, 128),
cluster_shape_mnk=(1, 1, 1),
kernel_schedule=TmaMI,
epilogue_schedule=TmaCoop,
tile_scheduler=TileSchedulerType.StreamK,
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
("M > 32",
(
"M > 256",
ScheduleConfig(
tile_shape_mn=(128, 256),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
#### M = 129-256
(
"M > 128 && K <= 4096 && N <= 4096",
ScheduleConfig(
tile_shape_mn=(128, 64),
cluster_shape_mnk=(1, 1, 1),
kernel_schedule=TmaMI,
epilogue_schedule=TmaCoop,
tile_scheduler=TileSchedulerType.StreamK,
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
("M > 16",
(
"M > 128 && K <= 8192 && N <= 8192",
ScheduleConfig(
tile_shape_mn=(128, 128),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 128",
ScheduleConfig(
tile_shape_mn=(128, 256),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
#### M = 65-128
(
"M > 64 && K <= 4069 && N <= 4069",
ScheduleConfig(
tile_shape_mn=(128, 32),
cluster_shape_mnk=(1, 1, 1),
kernel_schedule=TmaMI,
epilogue_schedule=TmaCoop,
tile_scheduler=TileSchedulerType.StreamK,
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(None,
ScheduleConfig(tile_shape_mn=(128, 16),
(
"M > 64 && K <= 4069 && N <= 8192",
ScheduleConfig(
tile_shape_mn=(128, 64),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 64 && K >= 8192 && N >= 12288",
ScheduleConfig(
tile_shape_mn=(256, 128),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 64",
ScheduleConfig(
tile_shape_mn=(128, 128),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
#### M = 33-64
(
"M > 32 && K <= 6144 && N <= 6144",
ScheduleConfig(
tile_shape_mn=(128, 16),
cluster_shape_mnk=(1, 1, 1),
kernel_schedule=TmaMI,
epilogue_schedule=TmaCoop,
tile_scheduler=TileSchedulerType.StreamK))
**schedule_common_params # type: ignore
)),
(
"M > 32 && K >= 16384 && N >= 12288",
ScheduleConfig(
tile_shape_mn=(256, 64),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 32",
ScheduleConfig(
tile_shape_mn=(128, 64),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
#### M = 17-32
(
"M > 16 && K <= 12288 && N <= 8192",
ScheduleConfig(
tile_shape_mn=(128, 32),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 16",
ScheduleConfig(
tile_shape_mn=(256, 32),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
#### M = 1-16
(
"N >= 26624",
ScheduleConfig(
tile_shape_mn=(256, 16),
cluster_shape_mnk=(1, 1, 1),
**schedule_common_params # type: ignore
)),
(
None,
ScheduleConfig(
tile_shape_mn=(128, 16),
cluster_shape_mnk=(1, 1, 1),
**schedule_common_params # type: ignore
)),
]
schedules = list(set([x[1] for x in default_heuristic]))
impl_configs = []
GPTQ_kernel_type_configs = list(
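The generated dispatch walks these (condition, ScheduleConfig) pairs in order and falls through to the None entry. A rough Python sketch of the selection semantics (illustrative only, not the generated C++):

def pick_schedule(heuristic, M, N, K):
    for cond, cfg in heuristic:
        # conditions use C-style '&&'; translate for Python's eval
        if cond is None or eval(cond.replace("&&", "and"),
                                {"M": M, "N": N, "K": K}):
            return cfg

# e.g. M=512, N=4096, K=4096 -> tile_shape_mn=(128, 128), cluster_shape_mnk=(2, 1, 1)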


@ -152,7 +152,8 @@ struct MacheteKernelTemplate {
int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);
int const group_size = maybe_group_size.value_or(K);
int const group_size =
maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
int const scale_k = (K + group_size - 1) / group_size;
TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);


@ -71,7 +71,7 @@ torch::Tensor run_impl(PyTorchArguments args) {
auto arguments = MacheteKernel::create_arguments(
stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr,
layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0),
args.group_size.value_or(K));
args.group_size);
TORCH_CHECK(MacheteKernel::can_implement(arguments),
"Machete kernel cannot be run with these arguments");


@ -53,7 +53,7 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
// clang-format on
// Allocate output
torch::Tensor D = torch::empty_like(B);
torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous);
prepack_B<PrepackedLayoutB>(stream, B_ptr, layout_Bt,
static_cast<ElementB*>(D.mutable_data_ptr()));


@ -192,6 +192,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"-> Tensor");
ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
ops.impl("permute_cols", torch::kCUDA, &permute_cols);
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def(
"gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "


@ -31,6 +31,8 @@ MNK_SHAPES = [
(257, 4224, 4160),
(257, 4096, 4096),
(64, 4096, 4096),
(1024, 4096, 8192),
(1024, 8192, 4096),
]
ACT_TYPES = [torch.float16, torch.bfloat16]
@ -139,6 +141,7 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
output_ref = torch.matmul(a, w_ref)
for schedule in ops.machete_supported_schedules(wtype):
print(f"Testing schedule {schedule}")
output = ops.machete_gemm(
a,
b_q=w_q_machete,


@ -0,0 +1,15 @@
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm._custom_ops import permute_cols
@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)])
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16])
def test_permute_cols(shape, dtype):
x = torch.randn(shape, dtype=dtype).cuda()
perm = torch.randperm(x.shape[1]).to(torch.int).cuda()
opcheck(torch.ops._C.permute_cols, (x, perm))
y = permute_cols(x, perm)
torch.testing.assert_close(y, x[:, perm])


@ -438,7 +438,8 @@ try:
@torch.library.register_fake("_C::machete_prepack_B")
def machete_prepack_B_fake(b_q_weight: torch.Tensor,
b_type: ScalarType) -> torch.Tensor:
return torch.empty_like(b_q_weight)
return torch.empty_like(b_q_weight,
memory_format=torch.contiguous_format)
@torch.library.register_fake("_C::causal_conv1d_fwd")
def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
@ -625,6 +626,22 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
return torch.ops._C.machete_prepack_B(b_q_weight, b_type)
# TODO: has to be a better way to do this
try:
torch.ops._C.permute_cols # noqa B018
@torch.library.register_fake("_C::permute_cols")
def _permute_cols_fake(a: torch.Tensor,
perm: torch.Tensor) -> torch.Tensor:
return torch.empty_like(a)
except Exception:
pass
def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
return torch.ops._C.permute_cols(a, perm)
# fp8
def scaled_fp8_quant(
input: torch.Tensor,


@ -7,10 +7,11 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
@ -231,7 +232,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qweight", marlin_qweight)
replace_parameter(layer, "qweight", marlin_qweight)
# Permute scales from AWQ format to marlin format.
marlin_scales = marlin_permute_scales(
@ -239,7 +240,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size)
replace_tensor(layer, "scales", marlin_scales)
replace_parameter(layer, "scales", marlin_scales)
# Permute zero-points from AWQ format to marlin format.
marlin_zp = awq_to_marlin_zero_points(
@ -247,7 +248,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k=layer.num_groups,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qzeros", marlin_zp)
replace_parameter(layer, "qzeros", marlin_zp)
# Not-used
layer.g_idx = marlin_make_empty_g_idx(device)


@ -1,17 +1,16 @@
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Set
import torch
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
ActivationOrdering)
from vllm.model_executor.layers.quantization.kernels import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape)
marlin_repeat_scales_on_all_ranks)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
@ -19,6 +18,8 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
RowvLLMParameter)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
__all__ = ["CompressedTensorsWNA16"]
WNA16_SUPPORTED_TYPES_MAP = {
4: scalar_types.uint4b8,
@ -28,6 +29,7 @@ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsWNA16(CompressedTensorsScheme):
_kernel_backends_being_used: Set[str] = set()
def __init__(self,
strategy: str,
@ -52,35 +54,43 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]
# Verify supported on platform.
verify_marlin_supported(quant_type=self.quant_type,
group_size=self.group_size)
@classmethod
def get_min_capability(cls) -> int:
# ampere and up
return 80
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
def create_weights(self, layer: torch.nn.Module, output_size: int,
input_size: int, output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_type,
act_type=params_dtype,
group_size=self.group_size,
zero_points=False,
has_g_idx=self.has_g_idx
)
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsWNA16",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# If group_size is -1, we are in channelwise case.
group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = (input_size != input_size_per_partition)
partition_scales = not marlin_repeat_scales_on_all_ranks(
self.has_g_idx, self.group_size, row_parallel)
verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=input_size,
group_size=group_size)
scales_and_zp_size = input_size // group_size
if partition_scales:
@ -137,69 +147,17 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
weight_loader=weight_loader)
layer.register_parameter("weight_g_idx", weight_g_idx)
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
layer.group_size = group_size
self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="weight_packed",
w_s_param_name="weight_scale",
w_zp_param_name=None,
w_gidx_param_name="weight_g_idx")
# Checkpoints are serialized in compressed-tensors format, which is
# different from marlin format. Handle repacking here.
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.weight_packed.device
# Allocate marlin workspace.
layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device)
# Handle sorting for activation reordering if needed.
if self.has_g_idx:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
replace_tensor(layer, "weight_g_idx", g_idx)
else:
layer.weight_g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
# No zero-point
layer.weight_zp = marlin_make_empty_g_idx(device)
# Update for kernel
layer.weight_packed = torch.nn.Parameter(
layer.weight_packed.t().contiguous(), requires_grad=False)
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.squeeze().t().contiguous(), requires_grad=False)
# Repack weights from compressed-tensors format to marlin format.
marlin_qweight = ops.gptq_marlin_repack(
layer.weight_packed,
perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_type.size_bits)
replace_tensor(layer, "weight_packed", marlin_qweight)
# Permute scales from compressed-tensors format to marlin format.
# scale is required on all partitions if activation reordering
marlin_scales = marlin_permute_scales(
layer.weight_scale,
size_k=(layer.input_size
if self.has_g_idx else layer.input_size_per_partition),
size_n=layer.output_size_per_partition,
group_size=layer.group_size)
replace_tensor(layer, "weight_scale", marlin_scales)
self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return apply_gptq_marlin_linear(
input=x,
weight=layer.weight_packed,
weight_scale=layer.weight_scale,
weight_zp=layer.weight_zp,
g_idx=layer.weight_g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
wtype=self.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
is_k_full=True,
bias=bias)
return self.kernel.apply_weights(layer, x, bias)


@ -1,7 +1,6 @@
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Set, Union
import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
@ -11,12 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape)
check_marlin_supported, marlin_moe_permute_scales,
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
@ -159,6 +158,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
quant_config: The GPTQ Marlin quantization config.
"""
_kernel_backends_being_used: Set[str] = set()
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
@ -176,25 +177,34 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
del output_size
output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader")
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_config.quant_type,
act_type=params_dtype,
group_size=self.quant_config.group_size,
zero_points=False,
has_g_idx=self.quant_config.desc_act
)
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for GPTQMarlinLinearMethod",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=input_size,
group_size=group_size,
)
# Determine sharding
if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
self.quant_config.group_size,
@ -275,57 +285,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
is_row_parallel)
# Checkpoints are serialized in AutoGPTQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking, including the activation reordering case.
self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="qweight",
w_s_param_name="scales",
w_zp_param_name="qzeros",
w_gidx_param_name="g_idx")
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device
# required by torch.compile
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)
# Allocate marlin workspace
layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device)
# Handle sorting for activation reordering if needed.
if self.quant_config.desc_act:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
replace_tensor(layer, "g_idx", g_idx)
else:
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
# No zero-point
layer.zp = marlin_make_empty_g_idx(device)
# Repack weights from autogptq format to marlin format.
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits,
)
replace_tensor(layer, "qweight", marlin_qweight)
# Permute scales from autogptq format to marlin format.
marlin_scales = marlin_permute_scales(
layer.scales,
size_k=(layer.input_size if self.quant_config.desc_act else
layer.input_size_per_partition),
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size,
)
replace_tensor(layer, "scales", marlin_scales)
self.kernel.process_weights_after_loading(layer)
def apply(
self,
@ -333,20 +301,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return apply_gptq_marlin_linear(
input=x,
weight=layer.qweight,
weight_scale=layer.scales,
weight_zp=layer.zp,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
wtype=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
is_k_full=layer.is_k_full,
bias=bias,
)
return self.kernel.apply_weights(layer, x, bias)
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
@ -506,11 +461,11 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = layer.w2_g_idx[e][
w2_g_idx_sort_indices[e]]
replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx)
replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx)
replace_tensor(layer, "w13_g_idx_sort_indices",
replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
replace_parameter(layer, "w13_g_idx_sort_indices",
w13_g_idx_sort_indices)
replace_tensor(layer, "w2_g_idx_sort_indices",
replace_parameter(layer, "w2_g_idx_sort_indices",
w2_g_idx_sort_indices)
else:
# Reset g_idx related tensors
@ -544,7 +499,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits,
)
replace_tensor(layer, "w13_qweight", marlin_w13_qweight)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_qweight,
layer.w2_g_idx_sort_indices,
@ -552,7 +507,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits,
)
replace_tensor(layer, "w2_qweight", marlin_w2_qweight)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
@ -560,14 +515,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_tensor(layer, "w13_scales", marlin_w13_scales)
replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor,
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_tensor(layer, "w2_scales", marlin_w2_scales)
replace_parameter(layer, "w2_scales", marlin_w2_scales)
def apply(
self,


@ -0,0 +1,83 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional, Tuple
import torch
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.scalar_type import ScalarType
@dataclass
class MPLinearLayerConfig:
full_weight_shape: Tuple[int, int] # [in, out]
partition_weight_shape: Tuple[int, int]
weight_type: ScalarType
act_type: torch.dtype
group_size: int
zero_points: bool
has_g_idx: bool
class MPLinearKernel(ABC):
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
raise NotImplementedError
@classmethod
@abstractmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
raise NotImplementedError
def __init__(self,
c: MPLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
w_zp_param_name: Optional[str] = None,
w_gidx_param_name: Optional[str] = None) -> None:
assert self.can_implement(c)
self.config = c
self.w_q_name = w_q_param_name
self.w_s_name = w_s_param_name
self.w_zp_name = w_zp_param_name
self.w_gidx_name = w_gidx_param_name
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
raise NotImplementedError
@abstractmethod
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
raise NotImplementedError
def _transform_param(self, layer: torch.nn.Module, name: Optional[str],
fn: Callable) -> None:
if name is not None and getattr(layer, name, None) is not None:
old_param = getattr(layer, name)
new_param = fn(old_param)
# replace the parameter with torch.nn.Parameter for TorchDynamo
# compatibility
replace_parameter(
layer, name,
torch.nn.Parameter(new_param.data, requires_grad=False))
def _get_weight_params(
self, layer: torch.nn.Module
) -> Tuple[torch.Tensor, # w_q
torch.Tensor, # w_s
Optional[torch.Tensor], # w_zp,
Optional[torch.Tensor] # w_gidx
]:
return (
getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name),
getattr(layer, self.w_zp_name or "", None),
getattr(layer, self.w_gidx_name or "", None),
)
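For reference, a hypothetical backend only needs to fill in these four hooks to become selectable; a sketch (MyLinearKernel is not part of this change):

from typing import Optional, Tuple

import torch

from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import (
    MPLinearKernel, MPLinearLayerConfig)


class MyLinearKernel(MPLinearKernel):  # hypothetical example backend

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        if c.zero_points:
            return False, "zero points not supported by this example"
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # repack weights/scales into the kernel's preferred layout here,
        # e.g. via self._transform_param(layer, self.w_q_name, ...)
        pass

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        raise NotImplementedError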


@ -0,0 +1,72 @@
import os
from typing import List, Optional, Type
from vllm.model_executor.layers.quantization.kernels.machete import (
MacheteLinearKernel)
from vllm.model_executor.layers.quantization.kernels.marlin import (
MarlinLinearKernel)
from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import (
MPLinearKernel, MPLinearLayerConfig)
from vllm.platforms import current_platform
# in priority/performance order (when available)
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
MacheteLinearKernel,
MarlinLinearKernel,
]
def choose_mp_linear_kernel(
config: MPLinearLayerConfig,
compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
"""
Choose an MPLinearKernel that can implement the given config for the given
compute capability. Attempts to choose the best kernel in terms of
performance.
Args:
config (MPLinearLayerConfig): Description of the linear layer to be
implemented.
compute_capability (Optional[int], optional): The compute capability of
the target device; if None, `current_platform` is used to get the compute
capability. Defaults to None.
Raises:
ValueError: If no kernel can implement the given config.
Returns:
Type[MPLinearKernel]: Chosen kernel.
"""
if compute_capability is None:
if current_platform is None:
raise ValueError("Cannot determine compute capability")
_cc = current_platform.get_device_capability()
compute_capability = _cc[0] * 10 + _cc[1]
failure_reasons = []
for kernel in _POSSIBLE_KERNELS:
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\
.split(","):
failure_reasons.append(
f' {kernel.__name__} disabled by environment variable')
continue
if kernel.get_min_capability() > compute_capability:
failure_reasons.append(
f"{kernel.__name__} requires capability "
f"{kernel.get_min_capability()}, current compute capability "
f"is {compute_capability}")
continue
can_implement, failure_reason = kernel.can_implement(config)
if can_implement:
return kernel
else:
failure_reasons.append(
f' {kernel.__name__} cannot implement due to: {failure_reason}'
)
raise ValueError(
"Failed to find a kernel that can implement the "\
"WNA16 linear layer. Reasons: \n"
+ '\n'.join(failure_reasons))
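Usage sketch for the selector above (assumes vLLM is installed; the compute capability is passed explicitly here rather than queried from current_platform):

import torch

from vllm.model_executor.layers.quantization.kernels import (
    MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.scalar_type import scalar_types

cfg = MPLinearLayerConfig(
    full_weight_shape=(4096, 4096),  # [in, out]
    partition_weight_shape=(4096, 4096),
    weight_type=scalar_types.uint4b8,
    act_type=torch.float16,
    group_size=128,
    zero_points=False,
    has_g_idx=False,
)
# Setting VLLM_DISABLED_KERNELS="MacheteLinearKernel" would force the Marlin fallback.
kernel_cls = choose_mp_linear_kernel(cfg, compute_capability=90)
print(kernel_cls.__name__)  # MacheteLinearKernel on a Hopper-class device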


@ -0,0 +1,118 @@
from functools import partial
from typing import Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
query_machete_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_weights_into_int32, unpack_weights_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class MacheteLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 90
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]:
return False, "Act reordering currently not supported by Machete "\
"when the input features are partitioned across "\
"devices"
if c.zero_points:
return False, "Zero points currently not supported by "\
" Compressed Tensors + Machete (the kernel supports them"\
" but CompressedTensorsWNA16 does not, so support has"\
" not been added to MacheteLinearKernel yet)"
if c.weight_type not in query_machete_supported_quant_types(
c.zero_points):
return False, f"Quant type ({c.weight_type}) not supported by "\
"Machete, supported types are: "\
f"{query_machete_supported_quant_types(c.zero_points)}"
if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES:
return False, f"Group size ({c.group_size}) not supported by "\
"Machete, supported group sizes are: "\
f"{MACHETE_SUPPORTED_GROUP_SIZES}"
return check_machete_supports_shape(c.partition_weight_shape[0],
c.partition_weight_shape[1])
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config
if c.has_g_idx:
assert self.w_gidx_name is not None
perm = torch.argsort(getattr(layer, self.w_gidx_name))\
.to(torch.int)
self.act_perm = lambda x: x[:, perm]
# use `ops.permute_cols` if possible
if c.act_type in [torch.float16, torch.bfloat16] \
and c.partition_weight_shape[0] % 8 == 0:
self.act_perm = partial(ops.permute_cols, perm=perm)
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
if c.has_g_idx:
x_unpacked = unpack_weights_into_int32(x.data,
c.weight_type,
packed_dim=0)
x_perm = x_unpacked[perm, :]
x.data = pack_weights_into_int32(x_perm,
c.weight_type,
packed_dim=0)
x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
self.config.weight_type)
return x
def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = x.data.contiguous()
return x
# Repack weights and scales for Machete
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
c = self.config
w_q, w_s, _, _ = self._get_weight_params(layer)
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
if c.has_g_idx:
x_2d = self.act_perm(x_2d)
output = ops.machete_gemm(a=x_2d,
b_q=w_q,
b_type=c.weight_type,
b_zeros=None,
b_scales=w_s,
b_group_size=c.group_size)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
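The g_idx path above rests on a simple identity: permuting the activation columns and the weight rows by the same permutation leaves the product unchanged, so the packed weight rows can be sorted into group order at load time as long as the activations are permuted the same way at run time. A quick numeric check (illustrative only):

import torch

x = torch.randn(4, 16, dtype=torch.float64)
w = torch.randn(16, 8, dtype=torch.float64)
perm = torch.randperm(16)
assert torch.allclose(x @ w, x[:, perm] @ w[perm, :])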


@ -0,0 +1,132 @@
from typing import Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
query_marlin_supported_quant_types)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class MarlinLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if c.zero_points:
return False, "Zero points currently not supported by "\
" MarlinLinearKernel. Will be added when AWQMarlin "\
"is migrated over to using MPLinearKernel backend"
quant_types = query_marlin_supported_quant_types(c.zero_points)
if c.weight_type not in quant_types:
return False, f"Quant type ({c.weight_type}) not supported by"\
f" Marlin, supported types are: {quant_types}"
if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
return False, f"Group size ({c.group_size}) not supported by "\
"Marlin, supported group sizes are: "\
f"{MARLIN_SUPPORTED_GROUP_SIZES}"
return check_marlin_supports_shape(c.partition_weight_shape[0],
c.partition_weight_shape[1],
c.full_weight_shape[1],
c.group_size)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = getattr(layer, self.w_q_name).device
c = self.config
row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0])
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
# Allocate marlin workspace.
self.workspace = marlin_make_workspace(c.partition_weight_shape[1],
device)
# Default names since marlin requires empty parameters for these,
# TODO: remove this requirement from marlin (allow optional tensors)
if self.w_gidx_name is None:
self.w_gidx_name = "g_idx"
if self.w_zp_name is None:
self.w_zp_name = "w_zp"
if c.has_g_idx:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
getattr(layer, self.w_gidx_name))
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
else:
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
if c.zero_points:
pass
# TODO (lucas): add the following when AWQMarlin is migrated over to
# using MPLinearKernel backend
# self._transform_param(layer, self.w_zp_name, lambda x: \
# marlin_zero_points(
# x,
# size_k=c.partition_weight_shape[0],
# size_n=c.partition_weight_shape[1],
# num_bits=c.weight_type.size_bits))
else:
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x.data = ops.gptq_marlin_repack(x.data.contiguous(),
perm=layer.g_idx_sort_indices,
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits)
return x
def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = marlin_permute_scales(x.data.contiguous(),
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
group_size=c.group_size)
return x
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
c = self.config
w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)
# `process_weights_after_loading` will ensure w_zp and w_gidx are not
# None for marlin
return apply_gptq_marlin_linear(
input=x,
weight=w_q,
weight_scale=w_s,
weight_zp=w_zp, # type: ignore
g_idx=w_gidx, # type: ignore
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=self.workspace,
wtype=c.weight_type,
input_size_per_partition=c.partition_weight_shape[0],
output_size_per_partition=c.partition_weight_shape[1],
is_k_full=self.is_k_full,
bias=bias)


@ -0,0 +1,3 @@
from .layer_utils import replace_parameter, update_tensor_inplace
__all__ = ['update_tensor_inplace', 'replace_parameter']


@ -0,0 +1,33 @@
from typing import Union
import torch
def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor):
assert dst.dtype == src.dtype, "Tensors must have the same dtype"
# update tensor shape and stride
dst.as_strided_(src.shape, src.stride())
# If not the same underlying storage move tensor data
if dst.data_ptr() != src.data_ptr():
dst.copy_(src)
del src
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_parameter(mod: torch.nn.Module, name: str,
new: Union[torch.Tensor, torch.nn.Parameter]) -> None:
old = getattr(mod, name)
if old.dtype == new.dtype and \
old.untyped_storage().nbytes() == new.untyped_storage().nbytes():
# If we can, just update in-place to avoid re-registering;
# this can be faster if the underlying storage is the same
update_tensor_inplace(old, new)
else:
# Fallback re-register parameter
if not isinstance(new, torch.nn.Parameter):
new = torch.nn.Parameter(new)
mod.register_parameter(name, torch.nn.Parameter(new))
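A minimal usage sketch for replace_parameter (illustrative shapes; the transpose stands in for a real repacking step):

import torch

from vllm.model_executor.layers.quantization.utils import replace_parameter

lin = torch.nn.Linear(8, 8, bias=False)
lin.weight.requires_grad_(False)  # vLLM registers quantized weights with requires_grad=False
repacked = lin.weight.data.t().contiguous()
replace_parameter(lin, "weight", repacked)
assert isinstance(lin.weight, torch.nn.Parameter)
assert torch.equal(lin.weight, repacked)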


@ -0,0 +1,30 @@
from typing import List, Optional, Tuple
import torch
from vllm.scalar_type import ScalarType, scalar_types
MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128]
MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128]
def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]:
if zero_points:
return [scalar_types.uint4, scalar_types.uint8]
else:
return [scalar_types.uint4b8, scalar_types.uint8b128]
def query_machete_supported_act_types(zero_points: bool) -> List[torch.dtype]:
return [torch.float16, torch.bfloat16]
def check_machete_supports_shape(in_features: int, out_features: int) \
-> Tuple[bool, Optional[str]]:
if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0:
return False, "Input features size must be divisible by "\
f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}"
if out_features % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0:
return False, "Output features size must be divisible by "\
f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}"
return True, None
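The prepacked block shape translates to a simple divisibility check; for example (a sketch):

from vllm.model_executor.layers.quantization.utils.machete_utils import (
    check_machete_supports_shape)

print(check_machete_supports_shape(4096, 4096))       # (True, None)
print(check_machete_supports_shape(4096 + 32, 4096))  # fails the 64-divisibility check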


@ -120,6 +120,19 @@ def verify_marlin_supports_shape(output_size_per_partition: int,
"with --quantization gptq.")
def check_marlin_supports_shape(output_size_per_partition: int,
input_size_per_partition: int,
input_size: int, group_size: int) \
-> Tuple[bool, Optional[str]]:
try:
verify_marlin_supports_shape(output_size_per_partition,
input_size_per_partition, input_size,
group_size)
except ValueError as e:
return False, e.__str__()
return True, None
def marlin_make_workspace(output_size_per_partition: int,
device: torch.device) -> torch.Tensor:
max_workspace_size = (output_size_per_partition //
@ -148,6 +161,11 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
requires_grad=False)
def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
requires_grad=False)
def marlin_sort_g_idx(
g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
@ -240,17 +258,6 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
return marlin_zp
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(layer: torch.nn.Module, name: str,
new_t: torch.Tensor) -> None:
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr(layer, name).resize_(new_t.shape)
getattr(layer, name).copy_(new_t)
del new_t
def apply_gptq_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,


@ -20,6 +20,49 @@ FUSED_LAYER_NAME_MAPPING = {
}
def pack_weights_into_int32(w_q: torch.Tensor,
wtype: ScalarType,
packed_dim: int = 0):
# move dim to pack to the end
perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
inv_perm = tuple(perm.index(i) for i in range(len(perm)))
w_q_perm = w_q.permute(perm)
pack_factor = 32 // wtype.size_bits
mask = (1 << wtype.size_bits) - 1
new_shape_perm = list(w_q_perm.shape)
assert w_q_perm.shape[-1] % pack_factor == 0
new_shape_perm[-1] //= pack_factor
res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
for i in range(pack_factor):
res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i
return res.permute(inv_perm)
def unpack_weights_into_int32(w_q: torch.Tensor,
wtype: ScalarType,
packed_dim: int = 0):
# move dim to pack to the end
perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
inv_perm = tuple(perm.index(i) for i in range(len(perm)))
w_q_perm = w_q.permute(perm)
pack_factor = 32 // wtype.size_bits
mask = (1 << wtype.size_bits) - 1
new_shape_perm = list(w_q_perm.shape)
new_shape_perm[-1] *= pack_factor
res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
for i in range(pack_factor):
res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask
return res.permute(inv_perm)
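pack_weights_into_int32 and unpack_weights_into_int32 are exact inverses for values that fit in wtype; a round-trip sketch:

import torch

from vllm.scalar_type import scalar_types
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_weights_into_int32, unpack_weights_into_int32)

w_q = torch.randint(0, 16, (128, 256), dtype=torch.int32)  # 4-bit values
packed = pack_weights_into_int32(w_q, scalar_types.uint4, packed_dim=0)
assert packed.shape == (16, 256)  # 128 rows packed 8-per-int32
unpacked = unpack_weights_into_int32(packed, scalar_types.uint4, packed_dim=0)
assert torch.equal(unpacked, w_q)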
def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj


@ -328,6 +328,64 @@ class PackedvLLMParameter(ModelWeightParameter):
marlin_tile_size=self.marlin_tile_size)
def permute_param_layout_(param: BasevLLMParameter, input_dim: int,
output_dim: int, **kwargs) -> BasevLLMParameter:
"""
Permute a parameter's layout to the specified input and output dimensions,
useful for forcing the parameter into a known layout, for example, if I need
a packed (quantized) weight matrix to be in the layout
{input_dim = 0, output_dim = 1, packed_dim = 0}
then I can call:
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
to ensure x is in the correct layout (permuting it if required, and asserting
if it cannot be permuted into the requested layout)
"""
curr_input_dim = getattr(param, "input_dim", None)
curr_output_dim = getattr(param, "output_dim", None)
if curr_input_dim is None or curr_output_dim is None:
assert param.data.dim() == 2,\
"permute_param_layout_ only supports 2D parameters when either "\
"input_dim or output_dim is not set"
# if one of the dimensions is not set, set it to the opposite of the other
# we can only do this since we asserted the parameter is 2D above
if curr_input_dim is None:
assert curr_output_dim is not None,\
"either input or output dim must be set"
curr_input_dim = (curr_output_dim + 1) % 2
if curr_output_dim is None:
assert curr_input_dim is not None,\
"either input or output dim must be set"
curr_output_dim = (curr_input_dim + 1) % 2
# create permutation from the current layout to the layout with
# self.input_dim at input_dim and self.output_dim at output_dim preserving
# other dimensions
perm = [
i for i in range(param.data.dim())
if i not in [curr_input_dim, curr_output_dim]
]
perm.insert(input_dim, curr_input_dim)
perm.insert(output_dim, curr_output_dim)
if "packed_dim" in kwargs:
assert hasattr(param, "packed_dim") and\
param.packed_dim == perm[kwargs["packed_dim"]],\
"permute_param_layout_ currently doesn't support repacking"
param.data = param.data.permute(*perm)
if hasattr(param, "_input_dim"):
param._input_dim = input_dim
if hasattr(param, "_output_dim"):
param._output_dim = output_dim
if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
param._packed_dim = kwargs["packed_dim"]
return param
def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
marlin_tile_size):
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size