[Kernel] (2/N) Machete - Integrate into CompressedTensorsWNA16 and GPTQMarlin (#7701)
Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Divakar Verma <137818590+divakar-amd@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
parent ee5f34b1c2
commit 86e9c8df29
@@ -223,6 +223,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/gguf/gguf_kernel.cu"
     "csrc/quantization/fp8/fp8_marlin.cu"
     "csrc/custom_all_reduce.cu"
+    "csrc/permute_cols.cu"
     "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
     "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
     "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
@@ -4,8 +4,10 @@ import itertools
 import math
 import pickle as pkl
 import time
-from typing import Callable, Iterable, List, Tuple
+from itertools import product
+from typing import Callable, Iterable, List, Optional, Tuple
 
+import pandas as pd
 import torch
 import torch.utils.benchmark as TBenchmark
 from torch.utils.benchmark import Measurement as TMeasurement

@@ -84,6 +86,10 @@ def loop_over_weights(
         fn(a, w_ref, w_q, w_s)
 
 
+_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
+_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
+
+
 def bench(atype: torch.dtype,
           wtype: ScalarType,
           group_size: int,

@@ -94,6 +100,8 @@ def bench(atype: torch.dtype,
           sub_label: str,
           benchmark_marlinv1: bool = True,
           sweep_schedules: bool = True) -> Iterable[TMeasurement]:
+    global _SWEEP_SCHEDULES_RESULTS
+
     a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
     sub_label += f", L={len(weights)}"
 

@@ -163,6 +171,11 @@ def bench(atype: torch.dtype,
         best_schedule = None
         schedules = ops.machete_supported_schedules(wtype)
         for schedule in reversed(schedules):
+            schedule_M = int(schedule.split("_")[0].split("x")[1])
+
+            # Prune known bad schedules
+            if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
+                continue
 
             def run(a, _, w_q, w_s, schedule=schedule):
                 ops.machete_gemm(a,

@@ -175,6 +188,20 @@ def bench(atype: torch.dtype,
             res = bench_fn(label, sub_label, "machete_best",
                            lambda: loop_over_weights(a, weights_machete, run))
 
+            results_row = {
+                "M": m,
+                "K": k,
+                "N": n,
+                "group_size": group_size,
+                "schedule": schedule,
+                "median": res.median,
+            }
+            if _SWEEP_SCHEDULES_RESULTS is None:
+                _SWEEP_SCHEDULES_RESULTS = pd.DataFrame(
+                    columns=results_row.keys())
+            _SWEEP_SCHEDULES_RESULTS.\
+                loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row
+
             print(f" {res.median:5.5} ", schedule)
             if not best or res.median < best.median:
                 best = res

@@ -235,18 +262,22 @@ def run_square_bench(args):
     dim_sizes = list(
         range(args.dim_start, args.dim_end + 1, args.dim_increment))
     MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
 
     data = run(args.dtype, args.sweep_schedules, MKNs)
 
     make_output(data, MKNs, f"square_bench-{args.dtype}")
 
 
 def run_range_bench(args):
-    dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
-    n = len(dim_sizes)
-    Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
-    Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
-    Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
-    MKNs = list(zip(Ms, Ks, Ns))
+    m_start, k_start, n_start = [int(x) for x in args.dim_start.split(",")]
+    m_end, k_end, n_end = [int(x) for x in args.dim_end.split(",")]
+    m_increment, k_increment, n_increment = \
+        [int(x) for x in args.dim_increment.split(",")]
+    Ms = list(range(m_start, m_end + 1, m_increment))
+    Ks = list(range(k_start, k_end + 1, k_increment))
+    Ns = list(range(n_start, n_end + 1, n_increment))
+    MKNs = list(product(Ms, Ks, Ns))
 
     data = run(args.dtype, args.sweep_schedules, MKNs)
 
     make_output(data, MKNs, f"range_bench-{args.dtype}")

@@ -333,6 +364,9 @@ Benchmark Machete GEMM.
         action="store_true",
         help="Run a sweep over all supported schedules",
     )
+    parser.add_argument("--sweep-csv-out",
+                        help="CSV to store sweep results",
+                        default="sch_sweep_results.csv")
     subparsers = parser.add_subparsers(dest="cmd", required=True)
 
     square_parser = subparsers.add_parser("square_bench")

@@ -342,12 +376,21 @@ Benchmark Machete GEMM.
     square_parser.set_defaults(func=run_square_bench)
 
     range_parser = subparsers.add_parser("range_bench")
-    range_parser.add_argument("--dim-start", type=int, required=True)
-    range_parser.add_argument("--dim-end", type=int, required=True)
-    range_parser.add_argument("--dim-increment", type=int, required=True)
-    range_parser.add_argument("--m-constant", type=int, default=None)
-    range_parser.add_argument("--n-constant", type=int, default=None)
-    range_parser.add_argument("--k-constant", type=int, default=None)
+    range_parser.add_argument(
+        "--dim-start",
+        type=str,
+        required=True,
+        help="Start value for M,K,N as common separated list")
+    range_parser.add_argument(
+        "--dim-end",
+        type=str,
+        required=True,
+        help="End value (inclusive) for M,K,N as common separated list")
+    range_parser.add_argument(
+        "--dim-increment",
+        type=str,
+        required=True,
+        help="Increment value for M,K,N as common separated list")
    range_parser.set_defaults(func=run_range_bench)
 
     model_parser = subparsers.add_parser("model_bench")

@@ -369,4 +412,9 @@ Benchmark Machete GEMM.
     model_parser.set_defaults(func=run_model_bench)
 
     args = parser.parse_args()
+
+    _SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out
     args.func(args)
+
+    if _SWEEP_SCHEDULES_RESULTS is not None:
+        _SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV)
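With this change, range_bench takes M,K,N as comma-separated triples and benchmarks the full cross product of the three ranges. A small illustrative sketch of how a spec expands into shapes, mirroring run_range_bench above (the example values are made up, not from the commit):

# Illustrative sketch (not part of the commit): expand comma-separated
# M,K,N ranges into the benchmarked shapes, as run_range_bench now does.
from itertools import product

dim_start, dim_end, dim_increment = "128,1024,1024", "512,4096,4096", "128,1024,1024"
m_start, k_start, n_start = [int(x) for x in dim_start.split(",")]
m_end, k_end, n_end = [int(x) for x in dim_end.split(",")]
m_inc, k_inc, n_inc = [int(x) for x in dim_increment.split(",")]

Ms = list(range(m_start, m_end + 1, m_inc))   # [128, 256, 384, 512]
Ks = list(range(k_start, k_end + 1, k_inc))   # [1024, 2048, 3072, 4096]
Ns = list(range(n_start, n_end + 1, n_inc))   # [1024, 2048, 3072, 4096]
MKNs = list(product(Ms, Ks, Ns))              # 4 * 4 * 4 = 64 (M, K, N) shapes
print(MKNs[:2])  # [(128, 1024, 1024), (128, 1024, 2048)]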
benchmarks/kernels/requirements.txt (new file, 1 line)
@@ -0,0 +1 @@
+pandas
@@ -68,7 +68,13 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
                     name, ".stride(", idx, ") to be ", StrideEle::value);
         return StrideEle{};
       } else {
-        return tensor.stride(idx);
+        if (tensor.size(idx) == 1) {
+          // use 0 stride for dim with size 1, this is easier for
+          // cute/cutlass to optimize (helps the TMA code flatten dims)
+          return StrideEle{0};
+        } else {
+          return tensor.stride(idx);
+        }
       }
     } else {
       // Extra strides are assumed to be 0 or 1
@@ -113,6 +113,8 @@ torch::Tensor prepack_B(torch::Tensor const& B,
 
 };  // namespace machete
 
+torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
+
 torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                   torch::Tensor& b_meta,
                                   torch::Tensor& b_scales,
csrc/permute_cols.cu (new file, 88 lines)
@@ -0,0 +1,88 @@
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda_fp16.h>

static constexpr int default_threads = 256;
static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
// Currently only supports 16bit types (since we permute half types)
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  int start_row = block_rows * blockIdx.x;
  int finish_row = start_row + block_rows;
  if (finish_row > size_m) {
    finish_row = size_m;
  }
  int cur_block_rows = std::max(finish_row - start_row, 0);

  int row_stride = size_k * sizeof(half) / 16;

  auto permute_row = [&](int row) {
    int iters = size_k / default_threads;
    int rest = size_k % default_threads;

    int offset = row * row_stride;

    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
    half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);

    int base_k = 0;

    for (int i = 0; i < iters; i++) {
      int cur_k = base_k + threadIdx.x;
      int src_pos = perm_int_ptr[cur_k];

      out_half[cur_k] = a_row_half[src_pos];

      base_k += default_threads;
    }

    if (rest) {
      if (threadIdx.x < rest) {
        int cur_k = base_k + threadIdx.x;
        int src_pos = perm_int_ptr[cur_k];

        out_half[cur_k] = a_row_half[src_pos];
      }
    }
  };

  for (int i = 0; i < cur_block_rows; i++) {
    int cur_row = start_row + i;
    if (cur_row < size_m) {
      permute_row(cur_row);
    }
  }
}

// More efficient version of A[..., perm]
// taken from gptq_marlin.cu
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
  auto dev = A.get_device();
  auto stream = at::cuda::getCurrentCUDAStream(dev);

  TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16,
              "Currently only 16bit types are supported");
  TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
  TORCH_CHECK(A.size(-1) % 8 == 0,
              "A columns must be a multiple of 8 (128bits)");
  auto A_2d = A.view({-1, A.size(-1)});

  torch::Tensor D = torch::empty_like(A);
  int sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  int block_rows = div_ceil(A_2d.size(0), sms);
  permute_cols_kernel<<<sms, default_threads, 0, stream>>>(
      reinterpret_cast<int4 const*>(A_2d.const_data_ptr()),
      perm.const_data_ptr<int>(), reinterpret_cast<int4*>(D.mutable_data_ptr()),
      A_2d.size(0), A_2d.size(1), block_rows);
  return D;
}
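For reference, the op computes the same result as indexing the last dimension with perm; the 2-D reshape mirrors the A.view({-1, A.size(-1)}) done by the entry point above. A minimal PyTorch sketch of the intended semantics (not the kernel itself, names are illustrative):

import torch

def permute_cols_reference(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    # Flatten leading dims, gather columns by perm, restore the original shape.
    a_2d = a.reshape(-1, a.shape[-1])
    return a_2d[:, perm.long()].reshape(a.shape)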
@@ -157,7 +157,7 @@ TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
 TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
 
 
-@dataclass
+@dataclass(frozen=True)
 class ScheduleConfig:
     tile_shape_mn: Tuple[int, int]
     cluster_shape_mnk: Tuple[int, int, int]

@@ -328,56 +328,137 @@ def generate():
     # about how this works
     SCRIPT_DIR = os.path.dirname(__file__)
 
-    schedules = [
-        ScheduleConfig(
-            tile_shape_mn=tile_shape_mn,
-            cluster_shape_mnk=cluster_shape_mnk,
-            kernel_schedule=kernel_schedule,
-            epilogue_schedule=epilogue_schedule,
-            tile_scheduler=tile_scheduler,
-        ) for tile_shape_mn, cluster_shape_mnk in (
-            ((128, 16), (1, 1, 1)),
-            ((128, 32), (1, 1, 1)),
-            ((128, 64), (1, 1, 1)),
-            ((128, 128), (1, 1, 1)),
-        ) for kernel_schedule in (TmaMI, ) for epilogue_schedule in (TmaCoop, )
-        for tile_scheduler in (TileSchedulerType.StreamK, )
-    ]
+    schedule_common_params = dict(
+        kernel_schedule=TmaMI,
+        epilogue_schedule=TmaCoop,
+        tile_scheduler=TileSchedulerType.StreamK,
+    )
 
     # For now we use the same heuristic for all types
+    # Heuristic is currently tuned for H100s
     default_heuristic = [
-        ("M > 64",
-         ScheduleConfig(
-             tile_shape_mn=(128, 128),
-             cluster_shape_mnk=(1, 1, 1),
-             kernel_schedule=TmaMI,
-             epilogue_schedule=TmaCoop,
-             tile_scheduler=TileSchedulerType.StreamK,
-         )),
-        ("M > 32",
-         ScheduleConfig(
-             tile_shape_mn=(128, 64),
-             cluster_shape_mnk=(1, 1, 1),
-             kernel_schedule=TmaMI,
-             epilogue_schedule=TmaCoop,
-             tile_scheduler=TileSchedulerType.StreamK,
-         )),
-        ("M > 16",
-         ScheduleConfig(
-             tile_shape_mn=(128, 32),
-             cluster_shape_mnk=(1, 1, 1),
-             kernel_schedule=TmaMI,
-             epilogue_schedule=TmaCoop,
-             tile_scheduler=TileSchedulerType.StreamK,
-         )),
-        (None,
-         ScheduleConfig(tile_shape_mn=(128, 16),
-                        cluster_shape_mnk=(1, 1, 1),
-                        kernel_schedule=TmaMI,
-                        epilogue_schedule=TmaCoop,
-                        tile_scheduler=TileSchedulerType.StreamK))
+        #### M = 257+
+        (
+            "M > 256 && K <= 16384 && N <= 4096",
+            ScheduleConfig(
+                tile_shape_mn=(128, 128),
+                cluster_shape_mnk=(2, 1, 1),
+                **schedule_common_params  # type: ignore
+            )),
+        (
+            "M > 256",
+            ScheduleConfig(
+                tile_shape_mn=(128, 256),
+                cluster_shape_mnk=(2, 1, 1),
+                **schedule_common_params  # type: ignore
+            )),
+        #### M = 129-256
+        (
+            "M > 128 && K <= 4096 && N <= 4096",
+            ScheduleConfig(
+                tile_shape_mn=(128, 64),
+                cluster_shape_mnk=(2, 1, 1),
+                **schedule_common_params  # type: ignore
+            )),
+        (
+            "M > 128 && K <= 8192 && N <= 8192",
+            ScheduleConfig(
+                tile_shape_mn=(128, 128),
+                cluster_shape_mnk=(2, 1, 1),
+                **schedule_common_params  # type: ignore
+            )),
+        (
+            "M > 128",
+            ScheduleConfig(
+                tile_shape_mn=(128, 256),
+                cluster_shape_mnk=(2, 1, 1),
+                **schedule_common_params  # type: ignore
+            )),
+        #### M = 65-128
+        (
+            "M > 64 && K <= 4069 && N <= 4069",
+            ScheduleConfig(
+                tile_shape_mn=(128, 32),
+                cluster_shape_mnk=(2, 1, 1),
+                **schedule_common_params  # type: ignore
+            )),
+        (
+            "M > 64 && K <= 4069 && N <= 8192",
+            ScheduleConfig(
+                tile_shape_mn=(128, 64),
+                cluster_shape_mnk=(2, 1, 1),
+                **schedule_common_params  # type: ignore
+            )),
+        (
+            "M > 64 && K >= 8192 && N >= 12288",
+            ScheduleConfig(
+                tile_shape_mn=(256, 128),
+                cluster_shape_mnk=(2, 1, 1),
+                **schedule_common_params  # type: ignore
+            )),
+        (
+            "M > 64",
+            ScheduleConfig(
+                tile_shape_mn=(128, 128),
+                cluster_shape_mnk=(2, 1, 1),
+                **schedule_common_params  # type: ignore
+            )),
+        #### M = 33-64
+        (
+            "M > 32 && K <= 6144 && N <= 6144",
+            ScheduleConfig(
+                tile_shape_mn=(128, 16),
+                cluster_shape_mnk=(1, 1, 1),
+                **schedule_common_params  # type: ignore
+            )),
+        (
+            "M > 32 && K >= 16384 && N >= 12288",
+            ScheduleConfig(
+                tile_shape_mn=(256, 64),
+                cluster_shape_mnk=(2, 1, 1),
+                **schedule_common_params  # type: ignore
+            )),
+        (
+            "M > 32",
+            ScheduleConfig(
+                tile_shape_mn=(128, 64),
+                cluster_shape_mnk=(2, 1, 1),
+                **schedule_common_params  # type: ignore
+            )),
+        #### M = 17-32
+        (
+            "M > 16 && K <= 12288 && N <= 8192",
+            ScheduleConfig(
+                tile_shape_mn=(128, 32),
+                cluster_shape_mnk=(2, 1, 1),
+                **schedule_common_params  # type: ignore
+            )),
+        (
+            "M > 16",
+            ScheduleConfig(
+                tile_shape_mn=(256, 32),
+                cluster_shape_mnk=(2, 1, 1),
+                **schedule_common_params  # type: ignore
+            )),
+        #### M = 1-16
+        (
+            "N >= 26624",
+            ScheduleConfig(
+                tile_shape_mn=(256, 16),
+                cluster_shape_mnk=(1, 1, 1),
+                **schedule_common_params  # type: ignore
+            )),
+        (
+            None,
+            ScheduleConfig(
+                tile_shape_mn=(128, 16),
+                cluster_shape_mnk=(1, 1, 1),
+                **schedule_common_params  # type: ignore
+            )),
     ]
 
+    schedules = list(set([x[1] for x in default_heuristic]))
+
     impl_configs = []
 
     GPTQ_kernel_type_configs = list(
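Each heuristic entry pairs a C++-style condition on M, K and N with a ScheduleConfig, and the generator presumably renders the list into a first-match dispatch in the emitted kernel sources. A rough Python sketch of that first-match logic (illustrative only, not the generated C++):

# Illustrative only: first-match selection over (condition, config) pairs,
# where conditions use C++ syntax such as "M > 256 && K <= 16384 && N <= 4096".
def select_schedule(heuristic, M, K, N):
    for cond, cfg in heuristic:
        if cond is None:
            return cfg  # unconditional fallback (last entry)
        if eval(cond.replace("&&", "and"), {}, {"M": M, "K": K, "N": N}):
            return cfg
    raise ValueError("heuristic has no matching entry")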
@@ -152,7 +152,8 @@ struct MacheteKernelTemplate {
 
    int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);
 
-   int const group_size = maybe_group_size.value_or(K);
+   int const group_size =
+       maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
    int const scale_k = (K + group_size - 1) / group_size;
 
    TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
@@ -71,7 +71,7 @@ torch::Tensor run_impl(PyTorchArguments args) {
   auto arguments = MacheteKernel::create_arguments(
       stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr,
       layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0),
-      args.group_size.value_or(K));
+      args.group_size);
   TORCH_CHECK(MacheteKernel::can_implement(arguments),
               "Machete kernel cannot be run with these arguments");
 
@@ -53,7 +53,7 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
   // clang-format on
 
   // Allocate output
-  torch::Tensor D = torch::empty_like(B);
+  torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous);
 
   prepack_B<PrepackedLayoutB>(stream, B_ptr, layout_Bt,
                               static_cast<ElementB*>(D.mutable_data_ptr()));
@@ -192,6 +192,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "-> Tensor");
   ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);
 
+  ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
+  ops.impl("permute_cols", torch::kCUDA, &permute_cols);
+
   // gptq_marlin Optimized Quantized GEMM for GPTQ.
   ops.def(
       "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
@@ -31,6 +31,8 @@ MNK_SHAPES = [
     (257, 4224, 4160),
     (257, 4096, 4096),
     (64, 4096, 4096),
+    (1024, 4096, 8192),
+    (1024, 8192, 4096),
 ]
 
 ACT_TYPES = [torch.float16, torch.bfloat16]

@@ -139,6 +141,7 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
     output_ref = torch.matmul(a, w_ref)
 
     for schedule in ops.machete_supported_schedules(wtype):
+        print(f"Testing schedule {schedule}")
         output = ops.machete_gemm(
             a,
             b_q=w_q_machete,
tests/kernels/test_permute_cols.py (new file, 15 lines)
@@ -0,0 +1,15 @@
import pytest
import torch

from tests.kernels.utils import opcheck
from vllm._custom_ops import permute_cols


@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)])
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16])
def test_permute_cols(shape, dtype):
    x = torch.randn(shape, dtype=dtype).cuda()
    perm = torch.randperm(x.shape[1]).to(torch.int).cuda()
    opcheck(torch.ops._C.permute_cols, (x, perm))
    y = permute_cols(x, perm)
    torch.testing.assert_close(y, x[:, perm])
@@ -438,7 +438,8 @@ try:
     @torch.library.register_fake("_C::machete_prepack_B")
     def machete_prepack_B_fake(b_q_weight: torch.Tensor,
                                b_type: ScalarType) -> torch.Tensor:
-        return torch.empty_like(b_q_weight)
+        return torch.empty_like(b_q_weight,
+                                memory_format=torch.contiguous_format)
 
     @torch.library.register_fake("_C::causal_conv1d_fwd")
     def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,

@@ -625,6 +626,22 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
     return torch.ops._C.machete_prepack_B(b_q_weight, b_type)
 
 
+# TODO: has to be a better way to do this
+try:
+    torch.ops._C.permute_cols  # noqa B018
+
+    @torch.library.register_fake("_C::permute_cols")
+    def _permute_cols_fake(a: torch.Tensor,
+                           perm: torch.Tensor) -> torch.Tensor:
+        return torch.empty_like(a)
+
+except Exception:
+    pass
+
+
+def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
+    return torch.ops._C.permute_cols(a, perm)
+
+
 # fp8
 def scaled_fp8_quant(
     input: torch.Tensor,
@@ -7,10 +7,11 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
     marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
-    replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
+    verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            PackedvLLMParameter)

@@ -231,7 +232,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
             num_bits=self.quant_config.quant_type.size_bits)
-        replace_tensor(layer, "qweight", marlin_qweight)
+        replace_parameter(layer, "qweight", marlin_qweight)
 
         # Permute scales from AWQ format to marlin format.
         marlin_scales = marlin_permute_scales(

@@ -239,7 +240,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
             group_size=self.quant_config.group_size)
-        replace_tensor(layer, "scales", marlin_scales)
+        replace_parameter(layer, "scales", marlin_scales)
 
         # Permute zero-points from AWQ format to marlin format.
         marlin_zp = awq_to_marlin_zero_points(

@@ -247,7 +248,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
             size_k=layer.num_groups,
             size_n=layer.output_size_per_partition,
             num_bits=self.quant_config.quant_type.size_bits)
-        replace_tensor(layer, "qzeros", marlin_zp)
+        replace_parameter(layer, "qzeros", marlin_zp)
 
         # Not-used
         layer.g_idx = marlin_make_empty_g_idx(device)
@@ -1,17 +1,16 @@
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Set
 
 import torch
 
-from vllm import _custom_ops as ops
+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     ActivationOrdering)
+from vllm.model_executor.layers.quantization.kernels import (
+    MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
-    marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
-    marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
-    verify_marlin_supports_shape)
+    marlin_repeat_scales_on_all_ranks)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,

@@ -19,6 +18,8 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
                                            RowvLLMParameter)
 from vllm.scalar_type import scalar_types
 
+logger = init_logger(__name__)
+
 __all__ = ["CompressedTensorsWNA16"]
 WNA16_SUPPORTED_TYPES_MAP = {
     4: scalar_types.uint4b8,

@@ -28,6 +29,7 @@ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
 
 
 class CompressedTensorsWNA16(CompressedTensorsScheme):
+    _kernel_backends_being_used: Set[str] = set()
 
     def __init__(self,
                  strategy: str,

@@ -52,35 +54,43 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
 
         self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]
 
-        # Verify supported on platform.
-        verify_marlin_supported(quant_type=self.quant_type,
-                                group_size=self.group_size)
-
     @classmethod
     def get_min_capability(cls) -> int:
         # ampere and up
         return 80
 
-    def create_weights(self, layer: torch.nn.Module, input_size: int,
-                       output_partition_sizes: List[int],
+    def create_weights(self, layer: torch.nn.Module, output_size: int,
+                       input_size: int, output_partition_sizes: List[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
 
         output_size_per_partition = sum(output_partition_sizes)
 
+        mp_linear_kernel_config = MPLinearLayerConfig(
+            full_weight_shape=(input_size, output_size),
+            partition_weight_shape=\
+                (input_size_per_partition, output_size_per_partition),
+            weight_type=self.quant_type,
+            act_type=params_dtype,
+            group_size=self.group_size,
+            zero_points=False,
+            has_g_idx=self.has_g_idx
+        )
+
+        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
+
+        if kernel_type.__name__ not in self._kernel_backends_being_used:
+            logger.info("Using %s for CompressedTensorsWNA16",
+                        kernel_type.__name__)
+            self._kernel_backends_being_used.add(kernel_type.__name__)
+
         # If group_size is -1, we are in channelwise case.
         group_size = self.group_size if self.group_size != -1 else input_size
         row_parallel = (input_size != input_size_per_partition)
         partition_scales = not marlin_repeat_scales_on_all_ranks(
             self.has_g_idx, self.group_size, row_parallel)
 
-        verify_marlin_supports_shape(
-            output_size_per_partition=output_size_per_partition,
-            input_size_per_partition=input_size_per_partition,
-            input_size=input_size,
-            group_size=group_size)
-
         scales_and_zp_size = input_size // group_size
 
         if partition_scales:

@@ -137,69 +147,17 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                 weight_loader=weight_loader)
             layer.register_parameter("weight_g_idx", weight_g_idx)
 
-        layer.input_size_per_partition = input_size_per_partition
-        layer.output_size_per_partition = output_size_per_partition
-        layer.input_size = input_size
-        layer.group_size = group_size
+        self.kernel = kernel_type(mp_linear_kernel_config,
+                                  w_q_param_name="weight_packed",
+                                  w_s_param_name="weight_scale",
+                                  w_zp_param_name=None,
+                                  w_gidx_param_name="weight_g_idx")
 
     # Checkpoints are serialized in compressed-tensors format, which is
-    # different from marlin format. Handle repacking here.
+    # different from the format the kernel may want. Handle repacking here.
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        device = layer.weight_packed.device
-
-        # Allocate marlin workspace.
-        layer.workspace = marlin_make_workspace(
-            layer.output_size_per_partition, device)
-
-        # Handle sorting for activation reordering if needed.
-        if self.has_g_idx:
-            g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx)
-            layer.g_idx_sort_indices = g_idx_sort_indices
-            replace_tensor(layer, "weight_g_idx", g_idx)
-        else:
-            layer.weight_g_idx = marlin_make_empty_g_idx(device)
-            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
-
-        # No zero-point
-        layer.weight_zp = marlin_make_empty_g_idx(device)
-        # Update for kernel
-        layer.weight_packed = torch.nn.Parameter(
-            layer.weight_packed.t().contiguous(), requires_grad=False)
-        layer.weight_scale = torch.nn.Parameter(
-            layer.weight_scale.squeeze().t().contiguous(), requires_grad=False)
-
-        # Repack weights from compressed-tensors format to marlin format.
-        marlin_qweight = ops.gptq_marlin_repack(
-            layer.weight_packed,
-            perm=layer.g_idx_sort_indices,
-            size_k=layer.input_size_per_partition,
-            size_n=layer.output_size_per_partition,
-            num_bits=self.quant_type.size_bits)
-        replace_tensor(layer, "weight_packed", marlin_qweight)
-
-        # Permute scales from compressed-tensors format to marlin format.
-        # scale is required on all partitions if activation reordering
-        marlin_scales = marlin_permute_scales(
-            layer.weight_scale,
-            size_k=(layer.input_size
-                    if self.has_g_idx else layer.input_size_per_partition),
-            size_n=layer.output_size_per_partition,
-            group_size=layer.group_size)
-        replace_tensor(layer, "weight_scale", marlin_scales)
+        self.kernel.process_weights_after_loading(layer)
 
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
-        return apply_gptq_marlin_linear(
-            input=x,
-            weight=layer.weight_packed,
-            weight_scale=layer.weight_scale,
-            weight_zp=layer.weight_zp,
-            g_idx=layer.weight_g_idx,
-            g_idx_sort_indices=layer.g_idx_sort_indices,
-            workspace=layer.workspace,
-            wtype=self.quant_type,
-            output_size_per_partition=layer.output_size_per_partition,
-            input_size_per_partition=layer.input_size_per_partition,
-            is_k_full=True,
-            bias=bias)
+        return self.kernel.apply_weights(layer, x, bias)
@@ -1,7 +1,6 @@
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Union
 
 import torch
-from torch.nn import Parameter
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger

@@ -11,12 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.kernels import (
+    MPLinearLayerConfig, choose_mp_linear_kernel)
+from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
-    marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
-    marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
-    marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
-    verify_marlin_supports_shape)
+    check_marlin_supported, marlin_moe_permute_scales,
+    marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,

@@ -159,6 +158,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         quant_config: The GPTQ Marlin quantization config.
     """
 
+    _kernel_backends_being_used: Set[str] = set()
+
     def __init__(self, quant_config: GPTQMarlinConfig) -> None:
         self.quant_config = quant_config
 

@@ -176,25 +177,34 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ) -> None:
-        del output_size
         output_size_per_partition = sum(output_partition_sizes)
         is_row_parallel = input_size != input_size_per_partition
         weight_loader = extra_weight_attrs.get("weight_loader")
 
+        mp_linear_kernel_config = MPLinearLayerConfig(
+            full_weight_shape=(input_size, output_size),
+            partition_weight_shape=\
+                (input_size_per_partition, output_size_per_partition),
+            weight_type=self.quant_config.quant_type,
+            act_type=params_dtype,
+            group_size=self.quant_config.group_size,
+            zero_points=False,
+            has_g_idx=self.quant_config.desc_act
+        )
+
+        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
+
+        if kernel_type.__name__ not in self._kernel_backends_being_used:
+            logger.info("Using %s for GPTQMarlinLinearMethod",
+                        kernel_type.__name__)
+            self._kernel_backends_being_used.add(kernel_type.__name__)
+
         # Normalize group_size
         if self.quant_config.group_size != -1:
             group_size = self.quant_config.group_size
         else:
             group_size = input_size
 
-        verify_marlin_supports_shape(
-            output_size_per_partition=output_size_per_partition,
-            input_size_per_partition=input_size_per_partition,
-            input_size=input_size,
-            group_size=group_size,
-        )
-
         # Determine sharding
         if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
                                              self.quant_config.group_size,

@@ -275,57 +285,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         layer.register_parameter("g_idx", g_idx)
         layer.register_parameter("scales", scales)
         layer.register_parameter("qzeros", qzeros)
-        layer.input_size_per_partition = input_size_per_partition
-        layer.output_size_per_partition = output_size_per_partition
-        layer.input_size = input_size
-        layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
-                                           is_row_parallel)
-
-    # Checkpoints are serialized in AutoGPTQ format, which is different from the
-    # marlin format. This function is called after the weights are loaded.
-    # Here, we handle the repacking, including the activation reordering case.
+
+        self.kernel = kernel_type(mp_linear_kernel_config,
+                                  w_q_param_name="qweight",
+                                  w_s_param_name="scales",
+                                  w_zp_param_name="qzeros",
+                                  w_gidx_param_name="g_idx")
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        device = layer.qweight.device
-
-        # required by torch.compile
-        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
-        layer.scales = Parameter(layer.scales.data, requires_grad=False)
-
-        # Allocate marlin workspace
-        layer.workspace = marlin_make_workspace(
-            layer.output_size_per_partition, device)
-
-        # Handle sorting for activation reordering if needed.
-        if self.quant_config.desc_act:
-            g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
-            layer.g_idx_sort_indices = g_idx_sort_indices
-            replace_tensor(layer, "g_idx", g_idx)
-        else:
-            layer.g_idx = marlin_make_empty_g_idx(device)
-            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
-
-        # No zero-point
-        layer.zp = marlin_make_empty_g_idx(device)
-
-        # Repack weights from autogptq format to marlin format.
-        marlin_qweight = ops.gptq_marlin_repack(
-            layer.qweight,
-            perm=layer.g_idx_sort_indices,
-            size_k=layer.input_size_per_partition,
-            size_n=layer.output_size_per_partition,
-            num_bits=self.quant_config.quant_type.size_bits,
-        )
-        replace_tensor(layer, "qweight", marlin_qweight)
-
-        # Permute scales from autogptq format to marlin format.
-        marlin_scales = marlin_permute_scales(
-            layer.scales,
-            size_k=(layer.input_size if self.quant_config.desc_act else
-                    layer.input_size_per_partition),
-            size_n=layer.output_size_per_partition,
-            group_size=self.quant_config.group_size,
-        )
-        replace_tensor(layer, "scales", marlin_scales)
+        self.kernel.process_weights_after_loading(layer)
 
     def apply(
         self,

@@ -333,20 +301,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        return apply_gptq_marlin_linear(
-            input=x,
-            weight=layer.qweight,
-            weight_scale=layer.scales,
-            weight_zp=layer.zp,
-            g_idx=layer.g_idx,
-            g_idx_sort_indices=layer.g_idx_sort_indices,
-            workspace=layer.workspace,
-            wtype=self.quant_config.quant_type,
-            output_size_per_partition=layer.output_size_per_partition,
-            input_size_per_partition=layer.input_size_per_partition,
-            is_k_full=layer.is_k_full,
-            bias=bias,
-        )
+        return self.kernel.apply_weights(layer, x, bias)
 
 
 class GPTQMarlinMoEMethod(FusedMoEMethodBase):

@@ -506,12 +461,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
                     w13_g_idx_sort_indices[e]]
                 w2_sorted_g_idx[e] = layer.w2_g_idx[e][
                     w2_g_idx_sort_indices[e]]
-            replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx)
-            replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx)
-            replace_tensor(layer, "w13_g_idx_sort_indices",
-                           w13_g_idx_sort_indices)
-            replace_tensor(layer, "w2_g_idx_sort_indices",
-                           w2_g_idx_sort_indices)
+            replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
+            replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
+            replace_parameter(layer, "w13_g_idx_sort_indices",
+                              w13_g_idx_sort_indices)
+            replace_parameter(layer, "w2_g_idx_sort_indices",
+                              w2_g_idx_sort_indices)
         else:
             # Reset g_idx related tensors
             num_experts = layer.w13_g_idx.shape[0]

@@ -544,7 +499,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             layer.w13_qweight.shape[2],
             self.quant_config.quant_type.size_bits,
         )
-        replace_tensor(layer, "w13_qweight", marlin_w13_qweight)
+        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
         marlin_w2_qweight = ops.gptq_marlin_moe_repack(
             layer.w2_qweight,
             layer.w2_g_idx_sort_indices,

@@ -552,7 +507,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             layer.w2_qweight.shape[2],
             self.quant_config.quant_type.size_bits,
         )
-        replace_tensor(layer, "w2_qweight", marlin_w2_qweight)
+        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
         # Repack scales
         marlin_w13_scales = marlin_moe_permute_scales(
             s=layer.w13_scales,

@@ -560,14 +515,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             size_n=layer.w13_scales.shape[2],
             group_size=self.quant_config.group_size,
         )
-        replace_tensor(layer, "w13_scales", marlin_w13_scales)
+        replace_parameter(layer, "w13_scales", marlin_w13_scales)
         marlin_w2_scales = marlin_moe_permute_scales(
             s=layer.w2_scales,
             size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor,
             size_n=layer.w2_scales.shape[2],
             group_size=self.quant_config.group_size,
         )
-        replace_tensor(layer, "w2_scales", marlin_w2_scales)
+        replace_parameter(layer, "w2_scales", marlin_w2_scales)
 
     def apply(
         self,
vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py (new file, 83 lines)
@@ -0,0 +1,83 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import torch

from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.scalar_type import ScalarType


@dataclass
class MPLinearLayerConfig:
    full_weight_shape: Tuple[int, int]  # [in, out]
    partition_weight_shape: Tuple[int, int]
    weight_type: ScalarType
    act_type: torch.dtype
    group_size: int
    zero_points: bool
    has_g_idx: bool


class MPLinearKernel(ABC):

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        raise NotImplementedError

    def __init__(self,
                 c: MPLinearLayerConfig,
                 w_q_param_name: str,
                 w_s_param_name: str,
                 w_zp_param_name: Optional[str] = None,
                 w_gidx_param_name: Optional[str] = None) -> None:
        assert self.can_implement(c)
        self.config = c
        self.w_q_name = w_q_param_name
        self.w_s_name = w_s_param_name
        self.w_zp_name = w_zp_param_name
        self.w_gidx_name = w_gidx_param_name

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        raise NotImplementedError

    @abstractmethod
    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        raise NotImplementedError

    def _transform_param(self, layer: torch.nn.Module, name: Optional[str],
                         fn: Callable) -> None:
        if name is not None and getattr(layer, name, None) is not None:

            old_param = getattr(layer, name)
            new_param = fn(old_param)
            # replace the parameter with torch.nn.Parameter for TorchDynamo
            # compatibility
            replace_parameter(
                layer, name,
                torch.nn.Parameter(new_param.data, requires_grad=False))

    def _get_weight_params(
        self, layer: torch.nn.Module
    ) -> Tuple[torch.Tensor,  # w_q
               torch.Tensor,  # w_s
               Optional[torch.Tensor],  # w_zp,
               Optional[torch.Tensor]  # w_gidx
               ]:
        return (
            getattr(layer, self.w_q_name),
            getattr(layer, self.w_s_name),
            getattr(layer, self.w_zp_name or "", None),
            getattr(layer, self.w_gidx_name or "", None),
        )
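A backend subclass only has to report its minimum compute capability, say whether it can implement a given config, repack the weights once after loading, and run the matmul. A minimal hypothetical skeleton, for illustration only (the class name and its behavior are invented; the real backends are MacheteLinearKernel and MarlinLinearKernel below):

# Hypothetical skeleton of an MPLinearKernel backend; not part of this commit.
from typing import Optional, Tuple

import torch

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class NaiveDequantLinearKernel(MPLinearKernel):
    """Illustrative dequantize-then-matmul fallback."""

    @classmethod
    def get_min_capability(cls) -> int:
        return 70  # any reasonably modern GPU

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        if c.has_g_idx:
            return False, "activation reordering not handled by this sketch"
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Repack/transform the loaded parameter into whatever layout
        # apply_weights expects, using the base-class helper.
        self._transform_param(layer, self.w_q_name, lambda p: p.contiguous())

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        w_q, w_s, _, _ = self._get_weight_params(layer)
        # A real backend would dequantize w_q with w_s and matmul here.
        raise NotImplementedError("illustrative skeleton only")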
72
vllm/model_executor/layers/quantization/kernels/__init__.py
Normal file
72
vllm/model_executor/layers/quantization/kernels/__init__.py
Normal file
@ -0,0 +1,72 @@
import os
from typing import List, Optional, Type

from vllm.model_executor.layers.quantization.kernels.machete import (
    MacheteLinearKernel)
from vllm.model_executor.layers.quantization.kernels.marlin import (
    MarlinLinearKernel)
from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import (
    MPLinearKernel, MPLinearLayerConfig)
from vllm.platforms import current_platform

# in priority/performance order (when available)
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
    MacheteLinearKernel,
    MarlinLinearKernel,
]


def choose_mp_linear_kernel(
        config: MPLinearLayerConfig,
        compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
    """
    Choose an MPLinearKernel that can implement the given config for the given
    compute capability. Attempts to choose the best kernel in terms of
    performance.

    Args:
        config (MPLinearLayerConfig): Description of the linear layer to be
          implemented.
        compute_capability (Optional[int], optional): The compute capability of
          the target device, if None uses `current_platform` to get the compute
          capability. Defaults to None.

    Raises:
        ValueError: If no kernel can implement the given config.

    Returns:
        Type[MPLinearKernel]: Chosen kernel.
    """
    if compute_capability is None:
        if current_platform is None:
            raise ValueError("Cannot determine compute capability")
        _cc = current_platform.get_device_capability()
        compute_capability = _cc[0] * 10 + _cc[1]

    failure_reasons = []
    for kernel in _POSSIBLE_KERNELS:
        if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\
            .split(","):
            failure_reasons.append(
                f' {kernel.__name__} disabled by environment variable')
            continue

        if kernel.get_min_capability() > compute_capability:
            failure_reasons.append(
                f"{kernel.__name__} requires capability "
                f"{kernel.get_min_capability()}, current compute capability "
                f"is {compute_capability}")
            continue

        can_implement, failure_reason = kernel.can_implement(config)
        if can_implement:
            return kernel
        else:
            failure_reasons.append(
                f' {kernel.__name__} cannot implement due to: {failure_reason}'
            )

    raise ValueError(
        "Failed to find a kernel that can implement the "\
        "WNA16 linear layer. Reasons: \n"
        + '\n'.join(failure_reasons))
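For reference, a minimal usage sketch of the selector above. The `MPLinearLayerConfig` keyword arguments mirror the fields the kernels read in `can_implement` (`full_weight_shape`, `partition_weight_shape`, `weight_type`, `act_type`, `group_size`, `zero_points`, `has_g_idx`); treat the exact constructor form and the example shapes as assumptions, not part of this change.

import os

import torch

from vllm.model_executor.layers.quantization.kernels import (
    choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import (
    MPLinearLayerConfig)
from vllm.scalar_type import scalar_types

# Disable Machete via the env var checked above (comma-separated class names),
# forcing the selector to fall back to Marlin even on capability >= 90.
os.environ["VLLM_DISABLED_KERNELS"] = "MacheteLinearKernel"

config = MPLinearLayerConfig(
    full_weight_shape=(4096, 4096),       # (in_features, out_features)
    partition_weight_shape=(4096, 4096),  # no tensor-parallel split
    weight_type=scalar_types.uint4b8,     # GPTQ-style 4-bit weights
    act_type=torch.float16,
    group_size=128,
    zero_points=False,
    has_g_idx=False,
)

kernel_cls = choose_mp_linear_kernel(config, compute_capability=90)
print(kernel_cls.__name__)  # expected: MarlinLinearKernel, since Machete was disabled above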
118 vllm/model_executor/layers/quantization/kernels/machete.py Normal file
@@ -0,0 +1,118 @@
from functools import partial
from typing import Optional, Tuple

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
    MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
    query_machete_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_weights_into_int32, unpack_weights_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           permute_param_layout_)

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class MacheteLinearKernel(MPLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        if c.has_g_idx and\
            c.partition_weight_shape[0] != c.full_weight_shape[0]:
            return False, "Act reordering currently not supported by Machete "\
                          "when the input features are partitioned across "\
                          "devices"

        if c.zero_points:
            return False, "Zero points currently not supported by "\
                          "Compressed Tensors + Machete. (Kernel supports it "\
                          "but CompressedTensorsWNA16 does not, so support has "\
                          "not been added to MacheteLinearKernel yet)"

        if c.weight_type not in query_machete_supported_quant_types(
                c.zero_points):
            return False, f"Quant type ({c.weight_type}) not supported by "\
                          "Machete, supported types are: "\
                          f"{query_machete_supported_quant_types(c.zero_points)}"

        if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES:
            return False, f"Group size ({c.group_size}) not supported by "\
                          "Machete, supported group sizes are: "\
                          f"{MACHETE_SUPPORTED_GROUP_SIZES}"

        return check_machete_supports_shape(c.partition_weight_shape[0],
                                            c.partition_weight_shape[1])

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module):
        c = self.config

        if c.has_g_idx:
            assert self.w_gidx_name is not None
            perm = torch.argsort(getattr(layer, self.w_gidx_name))\
                .to(torch.int)

            self.act_perm = lambda x: x[:, perm]
            # use `ops.permute_cols` if possible
            if c.act_type in [torch.float16, torch.bfloat16] \
                and c.partition_weight_shape[0] % 8 == 0:
                self.act_perm = partial(ops.permute_cols, perm=perm)

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            if c.has_g_idx:
                x_unpacked = unpack_weights_into_int32(x.data,
                                                       c.weight_type,
                                                       packed_dim=0)
                x_perm = x_unpacked[perm, :]
                x.data = pack_weights_into_int32(x_perm,
                                                 c.weight_type,
                                                 packed_dim=0)
            x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
                                           self.config.weight_type)
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x

        # Repack weights and scales for Machete
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        c = self.config
        w_q, w_s, _, _ = self._get_weight_params(layer)

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )

        if c.has_g_idx:
            x_2d = self.act_perm(x_2d)

        output = ops.machete_gemm(a=x_2d,
                                  b_q=w_q,
                                  b_type=c.weight_type,
                                  b_zeros=None,
                                  b_scales=w_s,
                                  b_group_size=c.group_size)

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
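The two activation-permutation paths set up in `process_weights_after_loading` are interchangeable; below is a minimal sketch of checking that the fused `ops.permute_cols` path agrees with the pure-PyTorch fallback. The shapes and the CUDA device are illustrative assumptions.

import torch

from vllm import _custom_ops as ops

k = 4096  # the fused path requires k % 8 == 0 and fp16/bf16 activations
x = torch.randn(16, k, dtype=torch.float16, device="cuda")
perm = torch.randperm(k, device="cuda").to(torch.int)

ref = x[:, perm]                      # fallback lambda used above
out = ops.permute_cols(x, perm=perm)  # fused CUDA kernel path
torch.testing.assert_close(out, ref)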
132 vllm/model_executor/layers/quantization/kernels/marlin.py Normal file
@@ -0,0 +1,132 @@
from typing import Optional, Tuple

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
    check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
    marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
    query_marlin_supported_quant_types)
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           permute_param_layout_)

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class MarlinLinearKernel(MPLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        if c.zero_points:
            return False, "Zero points currently not supported by "\
                          "MarlinLinearKernel. Will be added when AWQMarlin "\
                          "is migrated over to using MPLinearKernel backend"

        quant_types = query_marlin_supported_quant_types(c.zero_points)
        if c.weight_type not in quant_types:
            return False, f"Quant type ({c.weight_type}) not supported by"\
                          f" Marlin, supported types are: {quant_types}"

        if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
            return False, f"Group size ({c.group_size}) not supported by "\
                          "Marlin, supported group sizes are: "\
                          f"{MARLIN_SUPPORTED_GROUP_SIZES}"

        return check_marlin_supports_shape(c.partition_weight_shape[0],
                                           c.partition_weight_shape[1],
                                           c.full_weight_shape[1],
                                           c.group_size)

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = getattr(layer, self.w_q_name).device
        c = self.config

        row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0])
        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)

        # Allocate marlin workspace.
        self.workspace = marlin_make_workspace(c.partition_weight_shape[1],
                                               device)

        # Default names since marlin requires empty parameters for these,
        # TODO: remove this requirement from marlin (allow optional tensors)
        if self.w_gidx_name is None:
            self.w_gidx_name = "g_idx"
        if self.w_zp_name is None:
            self.w_zp_name = "w_zp"

        if c.has_g_idx:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
                getattr(layer, self.w_gidx_name))
            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        if c.zero_points:
            pass
            # TODO (lucas): add the following when AWQMarlin is migrated over to
            # using MPLinearKernel backend
            # self._transform_param(layer, self.w_zp_name, lambda x: \
            #     marlin_zero_points(
            #         x,
            #         size_k=c.partition_weight_shape[0],
            #         size_n=c.partition_weight_shape[1],
            #         num_bits=c.weight_type.size_bits))
        else:
            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = ops.gptq_marlin_repack(x.data.contiguous(),
                                            perm=layer.g_idx_sort_indices,
                                            size_k=c.partition_weight_shape[0],
                                            size_n=c.partition_weight_shape[1],
                                            num_bits=c.weight_type.size_bits)
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = marlin_permute_scales(x.data.contiguous(),
                                           size_k=c.partition_weight_shape[0],
                                           size_n=c.partition_weight_shape[1],
                                           group_size=c.group_size)
            return x

        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        c = self.config
        w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)

        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
        #  None for marlin
        return apply_gptq_marlin_linear(
            input=x,
            weight=w_q,
            weight_scale=w_s,
            weight_zp=w_zp,  # type: ignore
            g_idx=w_gidx,  # type: ignore
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=self.workspace,
            wtype=c.weight_type,
            input_size_per_partition=c.partition_weight_shape[0],
            output_size_per_partition=c.partition_weight_shape[1],
            is_k_full=self.is_k_full,
            bias=bias)
@@ -0,0 +1,3 @@
from .layer_utils import replace_parameter, update_tensor_inplace

__all__ = ['update_tensor_inplace', 'replace_parameter']
33 vllm/model_executor/layers/quantization/utils/layer_utils.py Normal file
@@ -0,0 +1,33 @@
from typing import Union

import torch


def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor):
    assert dst.dtype == src.dtype, "Tensors must have the same dtype"

    # update tensor shape and stride
    dst.as_strided_(src.shape, src.stride())

    # If not the same underlying storage move tensor data
    if dst.data_ptr() != src.data_ptr():
        dst.copy_(src)
        del src


# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_parameter(mod: torch.nn.Module, name: str,
                      new: Union[torch.Tensor, torch.nn.Parameter]) -> None:

    old = getattr(mod, name)
    if old.dtype == new.dtype and \
        old.untyped_storage().nbytes() == new.untyped_storage().nbytes():
        # If we can just update in-place to avoid re-registering
        # can be faster if the underlying storage is the same
        update_tensor_inplace(old, new)
    else:
        # Fallback re-register parameter
        if not isinstance(new, torch.nn.Parameter):
            new = torch.nn.Parameter(new)
        mod.register_parameter(name, torch.nn.Parameter(new))
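A small, illustrative exercise of `replace_parameter` above. The module and shapes are made up for the example; note that vLLM inference parameters have `requires_grad=False`, which the in-place path relies on here.

import torch

from vllm.model_executor.layers.quantization.utils import replace_parameter

layer = torch.nn.Linear(8, 8, bias=False)
layer.weight.requires_grad_(False)  # in-place ops below need a grad-free leaf
old_storage_ptr = layer.weight.untyped_storage().data_ptr()

# Same dtype and same storage size -> update_tensor_inplace path, storage reused.
replace_parameter(layer, "weight", torch.zeros(8, 8))
assert layer.weight.untyped_storage().data_ptr() == old_storage_ptr
assert torch.all(layer.weight == 0)

# Different dtype/size -> falls back to registering a fresh Parameter.
replace_parameter(layer, "weight", torch.zeros(4, 16, dtype=torch.float16))
assert layer.weight.shape == (4, 16)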
@@ -0,0 +1,30 @@
from typing import List, Optional, Tuple

import torch

from vllm.scalar_type import ScalarType, scalar_types

MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128]
MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128]


def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]:
    if zero_points:
        return [scalar_types.uint4, scalar_types.uint8]
    else:
        return [scalar_types.uint4b8, scalar_types.uint8b128]


def query_machete_supported_act_types(zero_points: bool) -> List[ScalarType]:
    return [torch.float16, torch.bfloat16]


def check_machete_supports_shape(in_features: int, out_featrues: int) \
    -> Tuple[bool, Optional[str]]:
    if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0:
        return False, "Input features size must be divisible by "\
            f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}"
    if out_featrues % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0:
        return False, "Output features size must be divisible by "\
            f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}"
    return True, None
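For reference, a quick check of the divisibility constraints above (K must be a multiple of 64 and N a multiple of 128, per MACHETE_PREPACKED_BLOCK_SHAPE); the example shapes are arbitrary.

from vllm.model_executor.layers.quantization.utils.machete_utils import (
    check_machete_supports_shape)

ok, reason = check_machete_supports_shape(4096, 11008)  # 4096 % 64 == 0, 11008 % 128 == 0
assert ok and reason is None

ok, reason = check_machete_supports_shape(4096, 11000)  # 11000 % 128 != 0
assert not ok
print(reason)  # "Output features size must be divisible by 128"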
@@ -120,6 +120,19 @@ def verify_marlin_supports_shape(output_size_per_partition: int,
            "with --quantization gptq.")


def check_marlin_supports_shape(output_size_per_partition: int,
                                input_size_per_partition: int,
                                input_size: int, group_size: int) \
        -> Tuple[bool, Optional[str]]:
    try:
        verify_marlin_supports_shape(output_size_per_partition,
                                     input_size_per_partition, input_size,
                                     group_size)
    except ValueError as e:
        return False, e.__str__()
    return True, None


def marlin_make_workspace(output_size_per_partition: int,
                          device: torch.device) -> torch.Tensor:
    max_workspace_size = (output_size_per_partition //
@@ -148,6 +161,11 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
                              requires_grad=False)


def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
    return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
                              requires_grad=False)


def marlin_sort_g_idx(
        g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
@@ -240,17 +258,6 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
    return marlin_zp


-# Newly generated tensors need to replace existing tensors that are
-# already registered as parameters by vLLM (and won't be freed)
-def replace_tensor(layer: torch.nn.Module, name: str,
-                   new_t: torch.Tensor) -> None:
-    # It is important to use resize_() here since it ensures
-    # the same buffer is reused
-    getattr(layer, name).resize_(new_t.shape)
-    getattr(layer, name).copy_(new_t)
-    del new_t
-
-
def apply_gptq_marlin_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
@@ -20,6 +20,49 @@ FUSED_LAYER_NAME_MAPPING = {
}


def pack_weights_into_int32(w_q: torch.Tensor,
                            wtype: ScalarType,
                            packed_dim: int = 0):
    # move dim to pack to the end
    perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
    w_q_perm = w_q.permute(perm)

    pack_factor = 32 // wtype.size_bits
    mask = (1 << wtype.size_bits) - 1

    new_shape_perm = list(w_q_perm.shape)
    assert w_q_perm.shape[-1] % pack_factor == 0
    new_shape_perm[-1] //= pack_factor

    res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
    for i in range(pack_factor):
        res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i

    return res.permute(inv_perm)


def unpack_weights_into_int32(w_q: torch.Tensor,
                              wtype: ScalarType,
                              packed_dim: int = 0):
    # move dim to pack to the end
    perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
    w_q_perm = w_q.permute(perm)

    pack_factor = 32 // wtype.size_bits
    mask = (1 << wtype.size_bits) - 1

    new_shape_perm = list(w_q_perm.shape)
    new_shape_perm[-1] *= pack_factor

    res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
    for i in range(pack_factor):
        res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask

    return res.permute(inv_perm)


def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj
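A minimal round-trip check for the packing helpers above, assuming 4-bit `uint4b8` weights (pack_factor = 32 // 4 = 8 values per int32 along the packed dim); the tensor shape is arbitrary.

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_weights_into_int32, unpack_weights_into_int32)
from vllm.scalar_type import scalar_types

wtype = scalar_types.uint4b8  # 4-bit type -> 8 values packed per int32
w_q = torch.randint(0, 16, (128, 256), dtype=torch.int32)

packed = pack_weights_into_int32(w_q, wtype, packed_dim=0)
assert packed.shape == (128 // 8, 256)

unpacked = unpack_weights_into_int32(packed, wtype, packed_dim=0)
assert torch.equal(unpacked, w_q)  # exact round trip for values in [0, 16)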
@@ -328,6 +328,64 @@ class PackedvLLMParameter(ModelWeightParameter):
            marlin_tile_size=self.marlin_tile_size)


def permute_param_layout_(param: BasevLLMParameter, input_dim: int,
                          output_dim: int, **kwargs) -> BasevLLMParameter:
    """
    Permute a parameter's layout to the specified input and output dimensions,
    useful for forcing the parameter into a known layout. For example, if a
    packed (quantized) weight matrix needs to be in the layout
        {input_dim = 0, output_dim = 1, packed_dim = 0}
    then call:
        permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
    to ensure x is in the correct layout (permuting it to the correct layout
    if required, asserting if it cannot be put in the correct layout)
    """

    curr_input_dim = getattr(param, "input_dim", None)
    curr_output_dim = getattr(param, "output_dim", None)

    if curr_input_dim is None or curr_output_dim is None:
        assert param.data.dim() == 2,\
            "permute_param_layout_ only supports 2D parameters when either "\
            "input_dim or output_dim is not set"

    # if one of the dimensions is not set, set it to the opposite of the other
    #  we can only do this since we asserted the parameter is 2D above
    if curr_input_dim is None:
        assert curr_output_dim is not None,\
            "either input or output dim must be set"
        curr_input_dim = (curr_output_dim + 1) % 2
    if curr_output_dim is None:
        assert curr_input_dim is not None,\
            "either input or output dim must be set"
        curr_output_dim = (curr_input_dim + 1) % 2

    # create permutation from the current layout to the layout with
    # self.input_dim at input_dim and self.output_dim at output_dim preserving
    # other dimensions
    perm = [
        i for i in range(param.data.dim())
        if i not in [curr_input_dim, curr_output_dim]
    ]
    perm.insert(input_dim, curr_input_dim)
    perm.insert(output_dim, curr_output_dim)

    if "packed_dim" in kwargs:
        assert hasattr(param, "packed_dim") and\
            param.packed_dim == perm[kwargs["packed_dim"]],\
            "permute_param_layout_ currently doesn't support repacking"

    param.data = param.data.permute(*perm)
    if hasattr(param, "_input_dim"):
        param._input_dim = input_dim
    if hasattr(param, "_output_dim"):
        param._output_dim = output_dim
    if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
        param._packed_dim = kwargs["packed_dim"]

    return param


def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
                                     marlin_tile_size):
    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
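A minimal sketch of forcing a weight into the {input_dim=0, output_dim=1} layout that the kernels above expect. It assumes `ModelWeightParameter` can be constructed directly with `data`/`input_dim`/`output_dim`/`weight_loader` keyword arguments; the shapes and the no-op loader are illustrative.

import torch

from vllm.model_executor.parameter import (ModelWeightParameter,
                                           permute_param_layout_)

# Weight stored as (out_features, in_features): input_dim=1, output_dim=0.
w = ModelWeightParameter(data=torch.randn(256, 128),
                         input_dim=1,
                         output_dim=0,
                         weight_loader=lambda *args, **kwargs: None)

permute_param_layout_(w, input_dim=0, output_dim=1)
assert w.data.shape == (128, 256)  # now (in_features, out_features)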