[Kernel] (2/N) Machete - Integrate into CompressedTensorsWNA16 and GPTQMarlin (#7701)

Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Divakar Verma <137818590+divakar-amd@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>

parent ee5f34b1c2
commit 86e9c8df29
@@ -223,6 +223,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
    "csrc/quantization/gguf/gguf_kernel.cu"
    "csrc/quantization/fp8/fp8_marlin.cu"
    "csrc/custom_all_reduce.cu"
    "csrc/permute_cols.cu"
    "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
    "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
    "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
@@ -4,8 +4,10 @@ import itertools
import math
import pickle as pkl
import time
from typing import Callable, Iterable, List, Tuple
from itertools import product
from typing import Callable, Iterable, List, Optional, Tuple

import pandas as pd
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
@@ -84,6 +86,10 @@ def loop_over_weights(
            fn(a, w_ref, w_q, w_s)


_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None


def bench(atype: torch.dtype,
          wtype: ScalarType,
          group_size: int,
@@ -94,6 +100,8 @@ def bench(atype: torch.dtype,
          sub_label: str,
          benchmark_marlinv1: bool = True,
          sweep_schedules: bool = True) -> Iterable[TMeasurement]:
    global _SWEEP_SCHEDULES_RESULTS

    a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
    sub_label += f", L={len(weights)}"

@@ -163,6 +171,11 @@ def bench(atype: torch.dtype,
        best_schedule = None
        schedules = ops.machete_supported_schedules(wtype)
        for schedule in reversed(schedules):
            schedule_M = int(schedule.split("_")[0].split("x")[1])

            # Prune known bad schedules
            if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
                continue

            def run(a, _, w_q, w_s, schedule=schedule):
                ops.machete_gemm(a,
@@ -175,6 +188,20 @@ def bench(atype: torch.dtype,
            res = bench_fn(label, sub_label, "machete_best",
                           lambda: loop_over_weights(a, weights_machete, run))

            results_row = {
                "M": m,
                "K": k,
                "N": n,
                "group_size": group_size,
                "schedule": schedule,
                "median": res.median,
            }
            if _SWEEP_SCHEDULES_RESULTS is None:
                _SWEEP_SCHEDULES_RESULTS = pd.DataFrame(
                    columns=results_row.keys())
            _SWEEP_SCHEDULES_RESULTS.\
                loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row

            print(f" {res.median:5.5} ", schedule)
            if not best or res.median < best.median:
                best = res
@@ -235,18 +262,22 @@ def run_square_bench(args):
    dim_sizes = list(
        range(args.dim_start, args.dim_end + 1, args.dim_increment))
    MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))

    data = run(args.dtype, args.sweep_schedules, MKNs)

    make_output(data, MKNs, f"square_bench-{args.dtype}")


def run_range_bench(args):
    dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
    n = len(dim_sizes)
    Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
    Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
    Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
    MKNs = list(zip(Ms, Ks, Ns))
    m_start, k_start, n_start = [int(x) for x in args.dim_start.split(",")]
    m_end, k_end, n_end = [int(x) for x in args.dim_end.split(",")]
    m_increment, k_increment, n_increment = \
        [int(x) for x in args.dim_increment.split(",")]
    Ms = list(range(m_start, m_end + 1, m_increment))
    Ks = list(range(k_start, k_end + 1, k_increment))
    Ns = list(range(n_start, n_end + 1, n_increment))
    MKNs = list(product(Ms, Ks, Ns))

    data = run(args.dtype, args.sweep_schedules, MKNs)

    make_output(data, MKNs, f"range_bench-{args.dtype}")
@@ -333,6 +364,9 @@ Benchmark Machete GEMM.
        action="store_true",
        help="Run a sweep over all supported schedules",
    )
    parser.add_argument("--sweep-csv-out",
                        help="CSV to store sweep results",
                        default="sch_sweep_results.csv")
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    square_parser = subparsers.add_parser("square_bench")
@@ -342,12 +376,21 @@ Benchmark Machete GEMM.
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--dim-start", type=int, required=True)
    range_parser.add_argument("--dim-end", type=int, required=True)
    range_parser.add_argument("--dim-increment", type=int, required=True)
    range_parser.add_argument("--m-constant", type=int, default=None)
    range_parser.add_argument("--n-constant", type=int, default=None)
    range_parser.add_argument("--k-constant", type=int, default=None)
    range_parser.add_argument(
        "--dim-start",
        type=str,
        required=True,
        help="Start value for M,K,N as comma-separated list")
    range_parser.add_argument(
        "--dim-end",
        type=str,
        required=True,
        help="End value (inclusive) for M,K,N as comma-separated list")
    range_parser.add_argument(
        "--dim-increment",
        type=str,
        required=True,
        help="Increment value for M,K,N as comma-separated list")
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
@@ -369,4 +412,9 @@ Benchmark Machete GEMM.
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()

    _SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out
    args.func(args)

    if _SWEEP_SCHEDULES_RESULTS is not None:
        _SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV)
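The pruning above keys off a tile M dimension embedded in each schedule name returned by ops.machete_supported_schedules. A minimal sketch of that parse, assuming names that begin with a tile-shape token such as "128x16_..." (the exact naming comes from the Machete kernel generator, so the string below is only illustrative):

    # Hypothetical schedule string; real names come from
    # ops.machete_supported_schedules(wtype).
    schedule = "128x16_1x1x1"
    m = 32  # problem M for this benchmark case

    tile_shape = schedule.split("_")[0]          # "128x16"
    schedule_M = int(tile_shape.split("x")[1])   # 16

    # Mirror of the benchmark's pruning rule above.
    skip = schedule_M >= 2 * max(m, 16) or schedule_M < m // 4
    print(skip)  # False: this schedule would be benchmarked for m=32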
benchmarks/kernels/requirements.txt (new file, 1 line)
@@ -0,0 +1 @@
pandas
@@ -68,7 +68,13 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
                  name, ".stride(", idx, ") to be ", StrideEle::value);
      return StrideEle{};
    } else {
      return tensor.stride(idx);
      if (tensor.size(idx) == 1) {
        // use 0 stride for dim with size 1, this is easier for
        // cute/cutlass to optimize (helps the TMA code flatten dims)
        return StrideEle{0};
      } else {
        return tensor.stride(idx);
      }
    }
  } else {
    // Extra strides are assumed to be 0 or 1
@@ -113,6 +113,8 @@ torch::Tensor prepack_B(torch::Tensor const& B,

};  // namespace machete

torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);

torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
csrc/permute_cols.cu (new file, 88 lines)
@@ -0,0 +1,88 @@
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda_fp16.h>

static constexpr int default_threads = 256;
static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
// Currently only supports 16bit types (since we permute half types)
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  int start_row = block_rows * blockIdx.x;
  int finish_row = start_row + block_rows;
  if (finish_row > size_m) {
    finish_row = size_m;
  }
  int cur_block_rows = std::max(finish_row - start_row, 0);

  int row_stride = size_k * sizeof(half) / 16;

  auto permute_row = [&](int row) {
    int iters = size_k / default_threads;
    int rest = size_k % default_threads;

    int offset = row * row_stride;

    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
    half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);

    int base_k = 0;

    for (int i = 0; i < iters; i++) {
      int cur_k = base_k + threadIdx.x;
      int src_pos = perm_int_ptr[cur_k];

      out_half[cur_k] = a_row_half[src_pos];

      base_k += default_threads;
    }

    if (rest) {
      if (threadIdx.x < rest) {
        int cur_k = base_k + threadIdx.x;
        int src_pos = perm_int_ptr[cur_k];

        out_half[cur_k] = a_row_half[src_pos];
      }
    }
  };

  for (int i = 0; i < cur_block_rows; i++) {
    int cur_row = start_row + i;
    if (cur_row < size_m) {
      permute_row(cur_row);
    }
  }
}

// More efficient version of A[..., perm]
// taken from gptq_marlin.cu
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
  auto dev = A.get_device();
  auto stream = at::cuda::getCurrentCUDAStream(dev);

  TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16,
              "Currently only 16bit types are supported");
  TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
  TORCH_CHECK(A.size(-1) % 8 == 0,
              "A columns must be a multiple of 8 (128bits)");
  auto A_2d = A.view({-1, A.size(-1)});

  torch::Tensor D = torch::empty_like(A);
  int sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  int block_rows = div_ceil(A_2d.size(0), sms);
  permute_cols_kernel<<<sms, default_threads, 0, stream>>>(
      reinterpret_cast<int4 const*>(A_2d.const_data_ptr()),
      perm.const_data_ptr<int>(), reinterpret_cast<int4*>(D.mutable_data_ptr()),
      A_2d.size(0), A_2d.size(1), block_rows);
  return D;
}
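From Python the new op is exposed as vllm._custom_ops.permute_cols and is simply a faster equivalent of fancy-indexing the last dimension. A small sanity check in the spirit of the test added later in this commit (assumes a CUDA build of vLLM with the op registered):

    import torch
    from vllm import _custom_ops as ops

    x = torch.randn(64, 4096, dtype=torch.float16, device="cuda")
    perm = torch.randperm(x.shape[1], device="cuda").to(torch.int)

    y = ops.permute_cols(x, perm)              # runs the CUDA kernel above
    torch.testing.assert_close(y, x[:, perm])  # matches plain fancy indexing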
@@ -157,7 +157,7 @@ TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative


@dataclass
@dataclass(frozen=True)
class ScheduleConfig:
    tile_shape_mn: Tuple[int, int]
    cluster_shape_mnk: Tuple[int, int, int]
@@ -328,56 +328,137 @@ def generate():
    # about how this works
    SCRIPT_DIR = os.path.dirname(__file__)

    schedules = [
        ScheduleConfig(
            tile_shape_mn=tile_shape_mn,
            cluster_shape_mnk=cluster_shape_mnk,
            kernel_schedule=kernel_schedule,
            epilogue_schedule=epilogue_schedule,
            tile_scheduler=tile_scheduler,
        ) for tile_shape_mn, cluster_shape_mnk in (
            ((128, 16), (1, 1, 1)),
            ((128, 32), (1, 1, 1)),
            ((128, 64), (1, 1, 1)),
            ((128, 128), (1, 1, 1)),
        ) for kernel_schedule in (TmaMI, ) for epilogue_schedule in (TmaCoop, )
        for tile_scheduler in (TileSchedulerType.StreamK, )
    ]
    schedule_common_params = dict(
        kernel_schedule=TmaMI,
        epilogue_schedule=TmaCoop,
        tile_scheduler=TileSchedulerType.StreamK,
    )

    # For now we use the same heuristic for all types
    # Heuristic is currently tuned for H100s
    default_heuristic = [
        ("M > 64",
         ScheduleConfig(tile_shape_mn=(128, 128),
                        cluster_shape_mnk=(1, 1, 1),
                        kernel_schedule=TmaMI,
                        epilogue_schedule=TmaCoop,
                        tile_scheduler=TileSchedulerType.StreamK)),
        ("M > 32",
         ScheduleConfig(tile_shape_mn=(128, 64),
                        cluster_shape_mnk=(1, 1, 1),
                        kernel_schedule=TmaMI,
                        epilogue_schedule=TmaCoop,
                        tile_scheduler=TileSchedulerType.StreamK)),
        ("M > 16",
         ScheduleConfig(tile_shape_mn=(128, 32),
                        cluster_shape_mnk=(1, 1, 1),
                        kernel_schedule=TmaMI,
                        epilogue_schedule=TmaCoop,
                        tile_scheduler=TileSchedulerType.StreamK)),
        (None,
         ScheduleConfig(tile_shape_mn=(128, 16),
                        cluster_shape_mnk=(1, 1, 1),
                        kernel_schedule=TmaMI,
                        epilogue_schedule=TmaCoop,
                        tile_scheduler=TileSchedulerType.StreamK))
        #### M = 257+
        ("M > 256 && K <= 16384 && N <= 4096",
         ScheduleConfig(tile_shape_mn=(128, 128),
                        cluster_shape_mnk=(2, 1, 1),
                        **schedule_common_params)),  # type: ignore
        ("M > 256",
         ScheduleConfig(tile_shape_mn=(128, 256),
                        cluster_shape_mnk=(2, 1, 1),
                        **schedule_common_params)),  # type: ignore
        #### M = 129-256
        ("M > 128 && K <= 4096 && N <= 4096",
         ScheduleConfig(tile_shape_mn=(128, 64),
                        cluster_shape_mnk=(2, 1, 1),
                        **schedule_common_params)),  # type: ignore
        ("M > 128 && K <= 8192 && N <= 8192",
         ScheduleConfig(tile_shape_mn=(128, 128),
                        cluster_shape_mnk=(2, 1, 1),
                        **schedule_common_params)),  # type: ignore
        ("M > 128",
         ScheduleConfig(tile_shape_mn=(128, 256),
                        cluster_shape_mnk=(2, 1, 1),
                        **schedule_common_params)),  # type: ignore
        #### M = 65-128
        ("M > 64 && K <= 4069 && N <= 4069",
         ScheduleConfig(tile_shape_mn=(128, 32),
                        cluster_shape_mnk=(2, 1, 1),
                        **schedule_common_params)),  # type: ignore
        ("M > 64 && K <= 4069 && N <= 8192",
         ScheduleConfig(tile_shape_mn=(128, 64),
                        cluster_shape_mnk=(2, 1, 1),
                        **schedule_common_params)),  # type: ignore
        ("M > 64 && K >= 8192 && N >= 12288",
         ScheduleConfig(tile_shape_mn=(256, 128),
                        cluster_shape_mnk=(2, 1, 1),
                        **schedule_common_params)),  # type: ignore
        ("M > 64",
         ScheduleConfig(tile_shape_mn=(128, 128),
                        cluster_shape_mnk=(2, 1, 1),
                        **schedule_common_params)),  # type: ignore
        #### M = 33-64
        ("M > 32 && K <= 6144 && N <= 6144",
         ScheduleConfig(tile_shape_mn=(128, 16),
                        cluster_shape_mnk=(1, 1, 1),
                        **schedule_common_params)),  # type: ignore
        ("M > 32 && K >= 16384 && N >= 12288",
         ScheduleConfig(tile_shape_mn=(256, 64),
                        cluster_shape_mnk=(2, 1, 1),
                        **schedule_common_params)),  # type: ignore
        ("M > 32",
         ScheduleConfig(tile_shape_mn=(128, 64),
                        cluster_shape_mnk=(2, 1, 1),
                        **schedule_common_params)),  # type: ignore
        #### M = 17-32
        ("M > 16 && K <= 12288 && N <= 8192",
         ScheduleConfig(tile_shape_mn=(128, 32),
                        cluster_shape_mnk=(2, 1, 1),
                        **schedule_common_params)),  # type: ignore
        ("M > 16",
         ScheduleConfig(tile_shape_mn=(256, 32),
                        cluster_shape_mnk=(2, 1, 1),
                        **schedule_common_params)),  # type: ignore
        #### M = 1-16
        ("N >= 26624",
         ScheduleConfig(tile_shape_mn=(256, 16),
                        cluster_shape_mnk=(1, 1, 1),
                        **schedule_common_params)),  # type: ignore
        (None,
         ScheduleConfig(tile_shape_mn=(128, 16),
                        cluster_shape_mnk=(1, 1, 1),
                        **schedule_common_params)),  # type: ignore
    ]

    schedules = list(set([x[1] for x in default_heuristic]))

    impl_configs = []

    GPTQ_kernel_type_configs = list(
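Each heuristic entry pairs a C++-style boolean expression over M, N and K with the ScheduleConfig to use when it is the first matching condition; None is the catch-all. A rough sketch (not part of the generator) of how such a table is interpreted, using Python's eval in place of the conditional dispatch the generator presumably emits into the C++ code:

    def pick_schedule(heuristic, M, N, K):
        # Conditions are tried in order; the first match wins.
        for cond, schedule in heuristic:
            if cond is None:
                return schedule  # fallback entry
            # Translate the C++ '&&' so the expression can be evaluated here.
            if eval(cond.replace("&&", " and "), {}, {"M": M, "N": N, "K": K}):
                return schedule
        raise ValueError("heuristic has no fallback entry")

    # e.g. with the new table, M=512, N=8192, K=8192 falls through to the
    # plain "M > 256" entry: tile (128, 256), cluster (2, 1, 1).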
@@ -152,7 +152,8 @@ struct MacheteKernelTemplate {

    int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);

    int const group_size = maybe_group_size.value_or(K);
    int const group_size =
        maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
    int const scale_k = (K + group_size - 1) / group_size;

    TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
@@ -71,7 +71,7 @@ torch::Tensor run_impl(PyTorchArguments args) {
  auto arguments = MacheteKernel::create_arguments(
      stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr,
      layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0),
      args.group_size.value_or(K));
      args.group_size);
  TORCH_CHECK(MacheteKernel::can_implement(arguments),
              "Machete kernel cannot be run with these arguments");

@@ -53,7 +53,7 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
  // clang-format on

  // Allocate output
  torch::Tensor D = torch::empty_like(B);
  torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous);

  prepack_B<PrepackedLayoutB>(stream, B_ptr, layout_Bt,
                              static_cast<ElementB*>(D.mutable_data_ptr()));
@@ -192,6 +192,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      "-> Tensor");
  ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);

  ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
  ops.impl("permute_cols", torch::kCUDA, &permute_cols);

  // gptq_marlin Optimized Quantized GEMM for GPTQ.
  ops.def(
      "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
@@ -31,6 +31,8 @@ MNK_SHAPES = [
    (257, 4224, 4160),
    (257, 4096, 4096),
    (64, 4096, 4096),
    (1024, 4096, 8192),
    (1024, 8192, 4096),
]

ACT_TYPES = [torch.float16, torch.bfloat16]
@@ -139,6 +141,7 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
    output_ref = torch.matmul(a, w_ref)

    for schedule in ops.machete_supported_schedules(wtype):
        print(f"Testing schedule {schedule}")
        output = ops.machete_gemm(
            a,
            b_q=w_q_machete,
tests/kernels/test_permute_cols.py (new file, 15 lines)
@@ -0,0 +1,15 @@
import pytest
import torch

from tests.kernels.utils import opcheck
from vllm._custom_ops import permute_cols


@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)])
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16])
def test_permute_cols(shape, dtype):
    x = torch.randn(shape, dtype=dtype).cuda()
    perm = torch.randperm(x.shape[1]).to(torch.int).cuda()
    opcheck(torch.ops._C.permute_cols, (x, perm))
    y = permute_cols(x, perm)
    torch.testing.assert_close(y, x[:, perm])
@@ -438,7 +438,8 @@ try:
    @torch.library.register_fake("_C::machete_prepack_B")
    def machete_prepack_B_fake(b_q_weight: torch.Tensor,
                               b_type: ScalarType) -> torch.Tensor:
        return torch.empty_like(b_q_weight)
        return torch.empty_like(b_q_weight,
                                memory_format=torch.contiguous_format)

    @torch.library.register_fake("_C::causal_conv1d_fwd")
    def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
@@ -625,6 +626,22 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
    return torch.ops._C.machete_prepack_B(b_q_weight, b_type)


# TODO: has to be a better way to do this
try:
    torch.ops._C.permute_cols  # noqa B018

    @torch.library.register_fake("_C::permute_cols")
    def _permute_cols_fake(a: torch.Tensor,
                           perm: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(a)
except Exception:
    pass


def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    return torch.ops._C.permute_cols(a, perm)


# fp8
def scaled_fp8_quant(
    input: torch.Tensor,
@@ -7,10 +7,11 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
    marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
    replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
    verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           PackedvLLMParameter)
@@ -231,7 +232,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.quant_type.size_bits)
        replace_tensor(layer, "qweight", marlin_qweight)
        replace_parameter(layer, "qweight", marlin_qweight)

        # Permute scales from AWQ format to marlin format.
        marlin_scales = marlin_permute_scales(
@@ -239,7 +240,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            group_size=self.quant_config.group_size)
        replace_tensor(layer, "scales", marlin_scales)
        replace_parameter(layer, "scales", marlin_scales)

        # Permute zero-points from AWQ format to marlin format.
        marlin_zp = awq_to_marlin_zero_points(
@@ -247,7 +248,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
            size_k=layer.num_groups,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.quant_type.size_bits)
        replace_tensor(layer, "qzeros", marlin_zp)
        replace_parameter(layer, "qzeros", marlin_zp)

        # Not-used
        layer.g_idx = marlin_make_empty_g_idx(device)
@@ -1,17 +1,16 @@
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Set

import torch

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    ActivationOrdering)
from vllm.model_executor.layers.quantization.kernels import (
    MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
    marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
    marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
    verify_marlin_supports_shape)
    marlin_repeat_scales_on_all_ranks)
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
@@ -19,6 +18,8 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
                                           RowvLLMParameter)
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)

__all__ = ["CompressedTensorsWNA16"]
WNA16_SUPPORTED_TYPES_MAP = {
    4: scalar_types.uint4b8,
@@ -28,6 +29,7 @@ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())


class CompressedTensorsWNA16(CompressedTensorsScheme):
    _kernel_backends_being_used: Set[str] = set()

    def __init__(self,
                 strategy: str,
@@ -52,35 +54,43 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):

        self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]

        # Verify supported on platform.
        verify_marlin_supported(quant_type=self.quant_type,
                                group_size=self.group_size)

    @classmethod
    def get_min_capability(cls) -> int:
        # ampere and up
        return 80

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_partition_sizes: List[int],
    def create_weights(self, layer: torch.nn.Module, output_size: int,
                       input_size: int, output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):

        output_size_per_partition = sum(output_partition_sizes)

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=\
                (input_size_per_partition, output_size_per_partition),
            weight_type=self.quant_type,
            act_type=params_dtype,
            group_size=self.group_size,
            zero_points=False,
            has_g_idx=self.has_g_idx
        )

        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for CompressedTensorsWNA16",
                        kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # If group_size is -1, we are in channelwise case.
        group_size = self.group_size if self.group_size != -1 else input_size
        row_parallel = (input_size != input_size_per_partition)
        partition_scales = not marlin_repeat_scales_on_all_ranks(
            self.has_g_idx, self.group_size, row_parallel)

        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size)

        scales_and_zp_size = input_size // group_size

        if partition_scales:
@@ -137,69 +147,17 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                                          weight_loader=weight_loader)
            layer.register_parameter("weight_g_idx", weight_g_idx)

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.group_size = group_size
        self.kernel = kernel_type(mp_linear_kernel_config,
                                  w_q_param_name="weight_packed",
                                  w_s_param_name="weight_scale",
                                  w_zp_param_name=None,
                                  w_gidx_param_name="weight_g_idx")

    # Checkpoints are serialized in compressed-tensors format, which is
    # different from marlin format. Handle repacking here.
    # different from the format the kernel may want. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = layer.weight_packed.device

        # Allocate marlin workspace.
        layer.workspace = marlin_make_workspace(
            layer.output_size_per_partition, device)

        # Handle sorting for activation reordering if needed.
        if self.has_g_idx:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
            replace_tensor(layer, "weight_g_idx", g_idx)
        else:
            layer.weight_g_idx = marlin_make_empty_g_idx(device)
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        # No zero-point
        layer.weight_zp = marlin_make_empty_g_idx(device)
        # Update for kernel
        layer.weight_packed = torch.nn.Parameter(
            layer.weight_packed.t().contiguous(), requires_grad=False)
        layer.weight_scale = torch.nn.Parameter(
            layer.weight_scale.squeeze().t().contiguous(), requires_grad=False)

        # Repack weights from compressed-tensors format to marlin format.
        marlin_qweight = ops.gptq_marlin_repack(
            layer.weight_packed,
            perm=layer.g_idx_sort_indices,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_type.size_bits)
        replace_tensor(layer, "weight_packed", marlin_qweight)

        # Permute scales from compressed-tensors format to marlin format.
        # scale is required on all partitions if activation reordering
        marlin_scales = marlin_permute_scales(
            layer.weight_scale,
            size_k=(layer.input_size
                    if self.has_g_idx else layer.input_size_per_partition),
            size_n=layer.output_size_per_partition,
            group_size=layer.group_size)
        replace_tensor(layer, "weight_scale", marlin_scales)
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:

        return apply_gptq_marlin_linear(
            input=x,
            weight=layer.weight_packed,
            weight_scale=layer.weight_scale,
            weight_zp=layer.weight_zp,
            g_idx=layer.weight_g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            wtype=self.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            is_k_full=True,
            bias=bias)
        return self.kernel.apply_weights(layer, x, bias)
@@ -1,7 +1,6 @@
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Set, Union

import torch
from torch.nn import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
@@ -11,12 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels import (
    MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
    marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
    marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
    marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
    verify_marlin_supports_shape)
    check_marlin_supported, marlin_moe_permute_scales,
    marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
@@ -159,6 +158,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
        quant_config: The GPTQ Marlin quantization config.
    """

    _kernel_backends_being_used: Set[str] = set()

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

@@ -176,25 +177,34 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:

        del output_size
        output_size_per_partition = sum(output_partition_sizes)
        is_row_parallel = input_size != input_size_per_partition
        weight_loader = extra_weight_attrs.get("weight_loader")

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=\
                (input_size_per_partition, output_size_per_partition),
            weight_type=self.quant_config.quant_type,
            act_type=params_dtype,
            group_size=self.quant_config.group_size,
            zero_points=False,
            has_g_idx=self.quant_config.desc_act
        )

        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for GPTQMarlinLinearMethod",
                        kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size,
        )

        # Determine sharding
        if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
                                             self.quant_config.group_size,
@@ -275,57 +285,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
                                           is_row_parallel)

    # Checkpoints are serialized in AutoGPTQ format, which is different from the
    # marlin format. This function is called after the weights are loaded.
    # Here, we handle the repacking, including the activation reordering case.
        self.kernel = kernel_type(mp_linear_kernel_config,
                                  w_q_param_name="qweight",
                                  w_s_param_name="scales",
                                  w_zp_param_name="qzeros",
                                  w_gidx_param_name="g_idx")

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = layer.qweight.device

        # required by torch.compile
        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
        layer.scales = Parameter(layer.scales.data, requires_grad=False)

        # Allocate marlin workspace
        layer.workspace = marlin_make_workspace(
            layer.output_size_per_partition, device)

        # Handle sorting for activation reordering if needed.
        if self.quant_config.desc_act:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
            replace_tensor(layer, "g_idx", g_idx)
        else:
            layer.g_idx = marlin_make_empty_g_idx(device)
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        # No zero-point
        layer.zp = marlin_make_empty_g_idx(device)

        # Repack weights from autogptq format to marlin format.
        marlin_qweight = ops.gptq_marlin_repack(
            layer.qweight,
            perm=layer.g_idx_sort_indices,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.quant_type.size_bits,
        )
        replace_tensor(layer, "qweight", marlin_qweight)

        # Permute scales from autogptq format to marlin format.
        marlin_scales = marlin_permute_scales(
            layer.scales,
            size_k=(layer.input_size if self.quant_config.desc_act else
                    layer.input_size_per_partition),
            size_n=layer.output_size_per_partition,
            group_size=self.quant_config.group_size,
        )
        replace_tensor(layer, "scales", marlin_scales)
        self.kernel.process_weights_after_loading(layer)

    def apply(
        self,
@@ -333,20 +301,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return apply_gptq_marlin_linear(
            input=x,
            weight=layer.qweight,
            weight_scale=layer.scales,
            weight_zp=layer.zp,
            g_idx=layer.g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            wtype=self.quant_config.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            is_k_full=layer.is_k_full,
            bias=bias,
        )
        return self.kernel.apply_weights(layer, x, bias)


class GPTQMarlinMoEMethod(FusedMoEMethodBase):
@@ -506,12 +461,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
                                                 w13_g_idx_sort_indices[e]]
                w2_sorted_g_idx[e] = layer.w2_g_idx[e][
                    w2_g_idx_sort_indices[e]]
            replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx)
            replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx)
            replace_tensor(layer, "w13_g_idx_sort_indices",
                           w13_g_idx_sort_indices)
            replace_tensor(layer, "w2_g_idx_sort_indices",
                           w2_g_idx_sort_indices)
            replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
            replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
            replace_parameter(layer, "w13_g_idx_sort_indices",
                              w13_g_idx_sort_indices)
            replace_parameter(layer, "w2_g_idx_sort_indices",
                              w2_g_idx_sort_indices)
        else:
            # Reset g_idx related tensors
            num_experts = layer.w13_g_idx.shape[0]
@@ -544,7 +499,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
            layer.w13_qweight.shape[2],
            self.quant_config.quant_type.size_bits,
        )
        replace_tensor(layer, "w13_qweight", marlin_w13_qweight)
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
@@ -552,7 +507,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
            layer.w2_qweight.shape[2],
            self.quant_config.quant_type.size_bits,
        )
        replace_tensor(layer, "w2_qweight", marlin_w2_qweight)
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
        # Repack scales
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
@@ -560,14 +515,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
            size_n=layer.w13_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_tensor(layer, "w13_scales", marlin_w13_scales)
        replace_parameter(layer, "w13_scales", marlin_w13_scales)
        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_scales,
            size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor,
            size_n=layer.w2_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_tensor(layer, "w2_scales", marlin_w2_scales)
        replace_parameter(layer, "w2_scales", marlin_w2_scales)

    def apply(
        self,
vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py (new file, 83 lines)
@@ -0,0 +1,83 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import torch

from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.scalar_type import ScalarType


@dataclass
class MPLinearLayerConfig:
    full_weight_shape: Tuple[int, int]  # [in, out]
    partition_weight_shape: Tuple[int, int]
    weight_type: ScalarType
    act_type: torch.dtype
    group_size: int
    zero_points: bool
    has_g_idx: bool


class MPLinearKernel(ABC):

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        raise NotImplementedError

    def __init__(self,
                 c: MPLinearLayerConfig,
                 w_q_param_name: str,
                 w_s_param_name: str,
                 w_zp_param_name: Optional[str] = None,
                 w_gidx_param_name: Optional[str] = None) -> None:
        assert self.can_implement(c)
        self.config = c
        self.w_q_name = w_q_param_name
        self.w_s_name = w_s_param_name
        self.w_zp_name = w_zp_param_name
        self.w_gidx_name = w_gidx_param_name

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        raise NotImplementedError

    @abstractmethod
    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        raise NotImplementedError

    def _transform_param(self, layer: torch.nn.Module, name: Optional[str],
                         fn: Callable) -> None:
        if name is not None and getattr(layer, name, None) is not None:

            old_param = getattr(layer, name)
            new_param = fn(old_param)
            # replace the parameter with torch.nn.Parameter for TorchDynamo
            # compatibility
            replace_parameter(
                layer, name,
                torch.nn.Parameter(new_param.data, requires_grad=False))

    def _get_weight_params(
            self, layer: torch.nn.Module
    ) -> Tuple[torch.Tensor,  # w_q
               torch.Tensor,  # w_s
               Optional[torch.Tensor],  # w_zp,
               Optional[torch.Tensor]  # w_gidx
               ]:
        return (
            getattr(layer, self.w_q_name),
            getattr(layer, self.w_s_name),
            getattr(layer, self.w_zp_name or "", None),
            getattr(layer, self.w_gidx_name or "", None),
        )
vllm/model_executor/layers/quantization/kernels/__init__.py (new file, 72 lines)
@@ -0,0 +1,72 @@
import os
from typing import List, Optional, Type

from vllm.model_executor.layers.quantization.kernels.machete import (
    MacheteLinearKernel)
from vllm.model_executor.layers.quantization.kernels.marlin import (
    MarlinLinearKernel)
from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import (
    MPLinearKernel, MPLinearLayerConfig)
from vllm.platforms import current_platform

# in priority/performance order (when available)
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
    MacheteLinearKernel,
    MarlinLinearKernel,
]


def choose_mp_linear_kernel(
        config: MPLinearLayerConfig,
        compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
    """
    Choose an MPLinearKernel that can implement the given config for the given
    compute capability. Attempts to choose the best kernel in terms of
    performance.

    Args:
        config (MPLinearLayerConfig): Description of the linear layer to be
          implemented.
        compute_capability (Optional[int], optional): The compute capability of
          the target device, if None uses `current_platform` to get the compute
          capability. Defaults to None.

    Raises:
        ValueError: If no kernel can implement the given config.

    Returns:
        Type[MPLinearKernel]: Chosen kernel.
    """
    if compute_capability is None:
        if current_platform is None:
            raise ValueError("Cannot determine compute capability")
        _cc = current_platform.get_device_capability()
        compute_capability = _cc[0] * 10 + _cc[1]

    failure_reasons = []
    for kernel in _POSSIBLE_KERNELS:
        if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\
            .split(","):
            failure_reasons.append(
                f' {kernel.__name__} disabled by environment variable')
            continue

        if kernel.get_min_capability() > compute_capability:
            failure_reasons.append(
                f"{kernel.__name__} requires capability "
                f"{kernel.get_min_capability()}, current compute capability "
                f"is {compute_capability}")
            continue

        can_implement, failure_reason = kernel.can_implement(config)
        if can_implement:
            return kernel
        else:
            failure_reasons.append(
                f' {kernel.__name__} cannot implement due to: {failure_reason}'
            )

    raise ValueError(
        "Failed to find a kernel that can implement the "\
        "WNA16 linear layer. Reasons: \n"
        + '\n'.join(failure_reasons))
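Downstream, CompressedTensorsWNA16 and GPTQMarlinLinearMethod call this selector with an MPLinearLayerConfig built from the layer's shapes and quantization parameters. An illustrative call with made-up shapes (not taken from the diff):

    import torch

    from vllm.model_executor.layers.quantization.kernels import (
        MPLinearLayerConfig, choose_mp_linear_kernel)
    from vllm.scalar_type import scalar_types

    cfg = MPLinearLayerConfig(
        full_weight_shape=(4096, 4096),       # (in_features, out_features)
        partition_weight_shape=(4096, 4096),  # no tensor-parallel split here
        weight_type=scalar_types.uint4b8,
        act_type=torch.float16,
        group_size=128,
        zero_points=False,
        has_g_idx=False,
    )

    # Prefers MacheteLinearKernel on capability >= 90 (Hopper), otherwise
    # MarlinLinearKernel; kernels can be excluded by name via the
    # VLLM_DISABLED_KERNELS environment variable.
    kernel_cls = choose_mp_linear_kernel(cfg)
    print(kernel_cls.__name__)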
vllm/model_executor/layers/quantization/kernels/machete.py (new file, 118 lines)
@@ -0,0 +1,118 @@
from functools import partial
from typing import Optional, Tuple

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
    MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
    query_machete_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_weights_into_int32, unpack_weights_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           permute_param_layout_)

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class MacheteLinearKernel(MPLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        if c.has_g_idx and\
            c.partition_weight_shape[0] != c.full_weight_shape[0]:
            return False, "Act reordering currently not supported by Machete, "\
                          "when the input features are partitioned across "\
                          "devices"

        if c.zero_points:
            return False, "Zero points currently not supported by "\
                          " Compressed Tensors + Machete. (Kernel supports it"\
                          " but CompressedTensorsWNA16 does not so support has"\
                          " not been added to MacheteWNA16Kernel yet"

        if c.weight_type not in query_machete_supported_quant_types(
                c.zero_points):
            return False, f"Quant type ({c.weight_type}) not supported by "\
                          "Machete, supported types are: "\
                          f"{query_machete_supported_quant_types(c.zero_points)}"

        if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES:
            return False, f"Group size ({c.group_size}) not supported by "\
                          "Machete, supported group sizes are: "\
                          f"{MACHETE_SUPPORTED_GROUP_SIZES}"

        return check_machete_supports_shape(c.partition_weight_shape[0],
                                            c.partition_weight_shape[1])

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module):
        c = self.config

        if c.has_g_idx:
            assert self.w_gidx_name is not None
            perm = torch.argsort(getattr(layer, self.w_gidx_name))\
                .to(torch.int)

            self.act_perm = lambda x: x[:, perm]
            # use `ops.permute_cols` if possible
            if c.act_type in [torch.float16, torch.bfloat16] \
                and c.partition_weight_shape[0] % 8 == 0:
                self.act_perm = partial(ops.permute_cols, perm=perm)

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            if c.has_g_idx:
                x_unpacked = unpack_weights_into_int32(x.data,
                                                       c.weight_type,
                                                       packed_dim=0)
                x_perm = x_unpacked[perm, :]
                x.data = pack_weights_into_int32(x_perm,
                                                 c.weight_type,
                                                 packed_dim=0)
            x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
                                           self.config.weight_type)
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x

        # Repack weights and scales for Machete
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        c = self.config
        w_q, w_s, _, _ = self._get_weight_params(layer)

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )

        if c.has_g_idx:
            x_2d = self.act_perm(x_2d)

        output = ops.machete_gemm(a=x_2d,
                                  b_q=w_q,
                                  b_type=c.weight_type,
                                  b_zeros=None,
                                  b_scales=w_s,
                                  b_group_size=c.group_size)

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
vllm/model_executor/layers/quantization/kernels/marlin.py (new file, 132 lines)
@@ -0,0 +1,132 @@
from typing import Optional, Tuple

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
    check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
    marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
    query_marlin_supported_quant_types)
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           permute_param_layout_)

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class MarlinLinearKernel(MPLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        if c.zero_points:
            return False, "Zero points currently not supported by "\
                          " MarlinLinearKernel. Will be added when AWQMarlin "\
                          "is migrated over to using MPLinearKernel backend"

        quant_types = query_marlin_supported_quant_types(c.zero_points)
        if c.weight_type not in quant_types:
            return False, f"Quant type ({c.weight_type}) not supported by"\
                          f" Marlin, supported types are: {quant_types}"

        if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
            return False, f"Group size ({c.group_size}) not supported by "\
                          "Marlin, supported group sizes are: "\
                          f"{MARLIN_SUPPORTED_GROUP_SIZES}"

        return check_marlin_supports_shape(c.partition_weight_shape[0],
                                           c.partition_weight_shape[1],
                                           c.full_weight_shape[1],
                                           c.group_size)

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = getattr(layer, self.w_q_name).device
        c = self.config

        row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0])
        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)

        # Allocate marlin workspace.
        self.workspace = marlin_make_workspace(c.partition_weight_shape[1],
                                               device)

        # Default names since marlin requires empty parameters for these,
        # TODO: remove this requirement from marlin (allow optional tensors)
        if self.w_gidx_name is None:
            self.w_gidx_name = "g_idx"
        if self.w_zp_name is None:
            self.w_zp_name = "w_zp"

        if c.has_g_idx:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
                getattr(layer, self.w_gidx_name))
            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        if c.zero_points:
            pass
            # TODO (lucas): add the following when AWQMarlin is migrated over to
            # using MPLinearKernel backend
            # self._transform_param(layer, self.w_zp_name, lambda x: \
            #     marlin_zero_points(
            #         x,
            #         size_k=c.partition_weight_shape[0],
            #         size_n=c.partition_weight_shape[1],
            #         num_bits=c.weight_type.size_bits))
        else:
            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = ops.gptq_marlin_repack(x.data.contiguous(),
                                            perm=layer.g_idx_sort_indices,
                                            size_k=c.partition_weight_shape[0],
                                            size_n=c.partition_weight_shape[1],
                                            num_bits=c.weight_type.size_bits)
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = marlin_permute_scales(x.data.contiguous(),
                                           size_k=c.partition_weight_shape[0],
                                           size_n=c.partition_weight_shape[1],
                                           group_size=c.group_size)
            return x

        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        c = self.config
        w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)

        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
        # None for marlin
        return apply_gptq_marlin_linear(
            input=x,
            weight=w_q,
            weight_scale=w_s,
            weight_zp=w_zp,  # type: ignore
            g_idx=w_gidx,  # type: ignore
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=self.workspace,
            wtype=c.weight_type,
            input_size_per_partition=c.partition_weight_shape[0],
            output_size_per_partition=c.partition_weight_shape[1],
            is_k_full=self.is_k_full,
            bias=bias)
vllm/model_executor/layers/quantization/utils/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .layer_utils import replace_parameter, update_tensor_inplace

__all__ = ['update_tensor_inplace', 'replace_parameter']
vllm/model_executor/layers/quantization/utils/layer_utils.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from typing import Union

import torch


def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor):
    assert dst.dtype == src.dtype, "Tensors must have the same dtype"

    # update tensor shape and stride
    dst.as_strided_(src.shape, src.stride())

    # If not the same underlying storage move tensor data
    if dst.data_ptr() != src.data_ptr():
        dst.copy_(src)
        del src


# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_parameter(mod: torch.nn.Module, name: str,
                      new: Union[torch.Tensor, torch.nn.Parameter]) -> None:

    old = getattr(mod, name)
    if old.dtype == new.dtype and \
        old.untyped_storage().nbytes() == new.untyped_storage().nbytes():
        # If we can just update in-place to avoid re-registering
        # can be faster if the underlying storage is the same
        update_tensor_inplace(old, new)
    else:
        # Fallback re-register parameter
        if not isinstance(new, torch.nn.Parameter):
            new = torch.nn.Parameter(new)
        mod.register_parameter(name, torch.nn.Parameter(new))
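An illustrative use of replace_parameter (not part of the diff): swapping a registered parameter for a freshly repacked tensor of the same size, which takes the in-place path:

    import torch

    from vllm.model_executor.layers.quantization.utils import replace_parameter

    layer = torch.nn.Module()
    layer.register_parameter(
        "qweight",
        torch.nn.Parameter(torch.zeros(128, 256, dtype=torch.int32),
                           requires_grad=False))

    repacked = torch.ones(128, 256, dtype=torch.int32)  # stand-in for a repack result
    replace_parameter(layer, "qweight", repacked)

    # Same dtype and storage size, so update_tensor_inplace() was used and the
    # existing parameter object now holds the new values.
    assert torch.equal(layer.qweight, repacked)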
@ -0,0 +1,30 @@
from typing import List, Optional, Tuple

import torch

from vllm.scalar_type import ScalarType, scalar_types

MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128]
MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128]


def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]:
    if zero_points:
        return [scalar_types.uint4, scalar_types.uint8]
    else:
        return [scalar_types.uint4b8, scalar_types.uint8b128]


def query_machete_supported_act_types(zero_points: bool) -> List[torch.dtype]:
    return [torch.float16, torch.bfloat16]


def check_machete_supports_shape(in_features: int, out_features: int) \
        -> Tuple[bool, Optional[str]]:
    if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0:
        return False, "Input features size must be divisible by "\
            f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}"
    if out_features % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0:
        return False, "Output features size must be divisible by "\
            f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}"
    return True, None
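A quick illustration (not part of the diff) of the shape gate above; it assumes `check_machete_supports_shape` from this new file is in scope. 4096 is divisible by both prepacked block dimensions (64 and 128), while 100 is not:

# assumes check_machete_supports_shape (defined above) is importable/in scope
ok, err = check_machete_supports_shape(in_features=4096, out_features=4096)
assert ok and err is None

ok, err = check_machete_supports_shape(in_features=4096, out_features=100)
assert not ok  # "Output features size must be divisible by 128"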
@ -120,6 +120,19 @@ def verify_marlin_supports_shape(output_size_per_partition: int,
                         "with --quantization gptq.")


def check_marlin_supports_shape(output_size_per_partition: int,
                                input_size_per_partition: int,
                                input_size: int, group_size: int) \
        -> Tuple[bool, Optional[str]]:
    try:
        verify_marlin_supports_shape(output_size_per_partition,
                                     input_size_per_partition, input_size,
                                     group_size)
    except ValueError as e:
        return False, str(e)
    return True, None


def marlin_make_workspace(output_size_per_partition: int,
                          device: torch.device) -> torch.Tensor:
    max_workspace_size = (output_size_per_partition //
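As with the Machete check above, this wrapper turns Marlin's exception-based shape validation into a (supported, reason) pair that kernel-selection code can branch on without try/except. A hedged sketch of how a caller might use it (the values are arbitrary and the helper is assumed to be in scope):

# assumes check_marlin_supports_shape (defined above) is in scope
ok, reason = check_marlin_supports_shape(output_size_per_partition=4096,
                                         input_size_per_partition=4096,
                                         input_size=4096,
                                         group_size=128)
if not ok:
    print(f"Skipping Marlin for this layer: {reason}")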
@ -148,6 +161,11 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
                              requires_grad=False)


def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
    return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
                              requires_grad=False)


def marlin_sort_g_idx(
        g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
@ -240,17 +258,6 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
    return marlin_zp


# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(layer: torch.nn.Module, name: str,
                   new_t: torch.Tensor) -> None:
    # It is important to use resize_() here since it ensures
    # the same buffer is reused
    getattr(layer, name).resize_(new_t.shape)
    getattr(layer, name).copy_(new_t)
    del new_t


def apply_gptq_marlin_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
@ -20,6 +20,49 @@ FUSED_LAYER_NAME_MAPPING = {
}


def pack_weights_into_int32(w_q: torch.Tensor,
                            wtype: ScalarType,
                            packed_dim: int = 0):
    # move the dim to pack to the end
    perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
    w_q_perm = w_q.permute(perm)

    pack_factor = 32 // wtype.size_bits
    mask = (1 << wtype.size_bits) - 1

    new_shape_perm = list(w_q_perm.shape)
    assert w_q_perm.shape[-1] % pack_factor == 0
    new_shape_perm[-1] //= pack_factor

    res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
    for i in range(pack_factor):
        res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i

    return res.permute(inv_perm)


def unpack_weights_into_int32(w_q: torch.Tensor,
                              wtype: ScalarType,
                              packed_dim: int = 0):
    # move the packed dim to the end
    perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
    w_q_perm = w_q.permute(perm)

    pack_factor = 32 // wtype.size_bits
    mask = (1 << wtype.size_bits) - 1

    new_shape_perm = list(w_q_perm.shape)
    new_shape_perm[-1] *= pack_factor

    res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
    for i in range(pack_factor):
        res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask

    return res.permute(inv_perm)


def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj
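A round-trip sketch (illustrative, not part of the diff) of the packing helpers above, using `scalar_types.uint4` so eight 4-bit values fit in each int32; the shapes are arbitrary and the helpers are assumed to be in scope:

import torch

from vllm.scalar_type import scalar_types

# 16 x 32 matrix of 4-bit values stored loosely in int32
w_q = torch.randint(0, 16, (16, 32), dtype=torch.int32)

# Pack along dim 0: 16 rows collapse to 16 // (32 bits / 4 bits) = 2 rows
packed = pack_weights_into_int32(w_q, scalar_types.uint4, packed_dim=0)
assert packed.shape == (2, 32)

# Unpacking restores the original values exactly
unpacked = unpack_weights_into_int32(packed, scalar_types.uint4, packed_dim=0)
assert torch.equal(unpacked, w_q)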
@ -328,6 +328,64 @@ class PackedvLLMParameter(ModelWeightParameter):
            marlin_tile_size=self.marlin_tile_size)


def permute_param_layout_(param: BasevLLMParameter, input_dim: int,
                          output_dim: int, **kwargs) -> BasevLLMParameter:
    """
    Permute a parameter's layout to the specified input and output dimensions.
    This is useful for forcing the parameter into a known layout; for example,
    if a packed (quantized) weight matrix must be in the layout
        {input_dim = 0, output_dim = 1, packed_dim = 0}
    then calling
        permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
    ensures x is in that layout, permuting it if required and asserting if it
    cannot be brought into the requested layout.
    """

    curr_input_dim = getattr(param, "input_dim", None)
    curr_output_dim = getattr(param, "output_dim", None)

    if curr_input_dim is None or curr_output_dim is None:
        assert param.data.dim() == 2,\
            "permute_param_layout_ only supports 2D parameters when either "\
            "input_dim or output_dim is not set"

    # if one of the dimensions is not set, set it to the opposite of the other;
    # we can only do this since we asserted the parameter is 2D above
    if curr_input_dim is None:
        assert curr_output_dim is not None,\
            "either input or output dim must be set"
        curr_input_dim = (curr_output_dim + 1) % 2
    if curr_output_dim is None:
        assert curr_input_dim is not None,\
            "either input or output dim must be set"
        curr_output_dim = (curr_input_dim + 1) % 2

    # create a permutation from the current layout to the layout with
    # self.input_dim at input_dim and self.output_dim at output_dim, preserving
    # the other dimensions
    perm = [
        i for i in range(param.data.dim())
        if i not in [curr_input_dim, curr_output_dim]
    ]
    perm.insert(input_dim, curr_input_dim)
    perm.insert(output_dim, curr_output_dim)

    if "packed_dim" in kwargs:
        assert hasattr(param, "packed_dim") and\
            param.packed_dim == perm[kwargs["packed_dim"]],\
            "permute_param_layout_ currently doesn't support repacking"

    param.data = param.data.permute(*perm)
    if hasattr(param, "_input_dim"):
        param._input_dim = input_dim
    if hasattr(param, "_output_dim"):
        param._output_dim = output_dim
    if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
        param._packed_dim = kwargs["packed_dim"]

    return param


def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
                                     marlin_tile_size):
    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
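To make the layout bookkeeping concrete, here is an illustrative sketch (not part of the diff) that drives `permute_param_layout_` with a minimal stand-in object; `ToyParam` is invented for the example and only mimics the attributes the function reads, sidestepping the `BasevLLMParameter` type annotation purely for demonstration. The function itself is assumed to be in scope.

import torch


class ToyParam:
    """Hypothetical stand-in exposing the attributes the function touches."""

    def __init__(self, data: torch.Tensor, input_dim: int, output_dim: int):
        self.data = data
        self._input_dim = input_dim
        self._output_dim = output_dim

    @property
    def input_dim(self):
        return self._input_dim

    @property
    def output_dim(self):
        return self._output_dim


# Checkpoint-style [out, in] weight: output_dim=0, input_dim=1
p = ToyParam(torch.randn(128, 256), input_dim=1, output_dim=0)

# Force the {input_dim=0, output_dim=1} layout expected by the kernel code
permute_param_layout_(p, input_dim=0, output_dim=1)

assert p.data.shape == (256, 128)
assert p.input_dim == 0 and p.output_dim == 1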