AQLM CUDA support (#3287)

Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
James Fleming 2024-04-23 13:59:33 -04:00 committed by GitHub
parent 62b5166bd4
commit 2b7949c1c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 1592 additions and 11 deletions

View File

@ -173,6 +173,7 @@ set(VLLM_EXT_SRC
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/marlin/marlin_cuda_kernel.cu"
"csrc/custom_all_reduce.cu")

View File

@ -0,0 +1,302 @@
import argparse
import os
import sys
from typing import Optional
import torch
import torch.nn.functional as F
from vllm._C import ops
from vllm.model_executor.layers.quantization.aqlm import (
dequantize_weight, generic_dequantize_gemm, get_int_dtype,
optimized_dequantize_gemm)
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def torch_mult(
input: torch.Tensor, # [..., in_features]
weights: torch.Tensor,
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
) -> torch.Tensor:
output = F.linear(input, weights)
return output
def dequant_out_scale(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
if bias is None:
output = F.linear(input, weights, bias)
orig_shape = output.shape
flattened_output = output.view(-1, output.size(-1))
f_scales = scales.view(-1, scales.shape[0])
b_scales = f_scales.expand(flattened_output.shape[0], -1)
flattened_output *= b_scales
return flattened_output.view(orig_shape)
else:
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
-1, weights.shape[1])
weights *= b_scales
return F.linear(input, weights, bias)
def dequant_weight_scale(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
-1, weights.shape[1])
weights *= b_scales
return F.linear(input, weights, bias)
def dequant_no_scale(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
return F.linear(input, weights, bias)
# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
# the generic pytorch version.
# Just visual comparison.
def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None:
n = parts.sum().item()
device = torch.device('cuda:0')
code_range = (1 << bits) // 2
ingroups = 8
codes = torch.randint(-code_range,
code_range,
size=(n, k // ingroups, nbooks),
dtype=get_int_dtype(bits),
device=device)
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
dtype=torch.float16,
device=device)
count = 0
for index in range(16):
for i in range(8):
for book in range(nbooks):
codebooks[book, index, 0, i] = count * (10**book)
count += 1
print("codes shape", codes.shape)
for i in range(16):
for book in range(nbooks):
codes[0, i, book] = i
codes[0, -i, book] = i
weights = dequantize_weight(codes, codebooks, None)
weights2 = ops.aqlm_dequant(codes, codebooks, parts)
print("weights shape:", weights.shape)
print("weights2 shape:", weights2.shape)
print("weights are:", weights)
print("weights2 are:", weights2)
print("first 128 weights are", weights[0, 0:128].to(torch.int32))
print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))
print("last 128 weights are", weights[0, -128:])
print("last 128 weights2 are:", weights2[0, -128:])
def main():
parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")
# Add arguments
parser.add_argument("--nbooks",
type=int,
default=1,
help="Number of codebooks (default: 1)")
parser.add_argument("--bits",
type=int,
default=16,
help="Number of bits per code element (default: 16)")
parser.add_argument(
"--test",
type=bool,
default=False,
help="Run the decompression/dequant tester rather than benchmarking "
"(default: False)")
# Parse the arguments
args = parser.parse_args()
# Extract values
nbooks = args.nbooks
bits = args.bits
if args.test:
dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
return
# Otherwise, benchmark.
methods = [
ops.aqlm_gemm,
dequant_out_scale,
generic_dequantize_gemm,
optimized_dequantize_gemm,
dequant_weight_scale,
torch_mult,
dequant_no_scale,
]
filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
print(f"writing benchmarks to file {filename}")
with open(filename, "w") as f:
sys.stdout = f
print('m | k | n | n parts', end='')
for method in methods:
print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
print('')
# These are reasonable prefill sizes.
ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
(4096, (11008, 11008)), (11008, (4096, )))
# reasonable ranges for m.
for m in [
1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
128, 256, 512, 1024, 1536, 2048, 3072, 4096
]:
print(f'{m}', file=sys.__stdout__)
for ksp in ksandpartions:
run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
methods)
sys.stdout = sys.__stdout__
def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int,
methods):
# I didn't see visible improvements from increasing these, but feel free :)
num_warmup_trials = 1
num_trials = 1
num_calls = 100
# warmup.
for method in methods:
for _ in range(num_warmup_trials):
run_timing(
num_calls=num_calls,
m=m,
k=k,
parts=parts,
nbooks=nbooks,
bits=bits,
method=method,
)
n = parts.sum().item()
print(f'{m} | {k} | {n} | {parts.tolist()}', end='')
for method in methods:
best_time_us = 1e20
for _ in range(num_trials):
kernel_dur_ms = run_timing(
num_calls=num_calls,
m=m,
k=k,
parts=parts,
nbooks=nbooks,
bits=bits,
method=method,
)
kernel_dur_us = 1000 * kernel_dur_ms
if kernel_dur_us < best_time_us:
best_time_us = kernel_dur_us
print(f' | {kernel_dur_us:.0f}', end='')
print('')
def run_timing(num_calls: int, m: int, k: int, parts: torch.tensor,
nbooks: int, bits: int, method) -> float:
n = parts.sum().item()
device = torch.device('cuda:0')
input = torch.randn((1, m, k), dtype=torch.float16, device=device)
code_range = (1 << bits) // 2
ingroups = 8
codes = torch.randint(-code_range,
code_range,
size=(n, k // ingroups, nbooks),
dtype=get_int_dtype(bits),
device=device)
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
dtype=torch.float16,
device=device)
scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)
# for comparison to just a pytorch mult.
weights = torch.randn((n, k), dtype=torch.float16, device=device)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
if method is torch_mult:
for i in range(num_calls):
torch_mult(input, weights, scales)
else:
for i in range(num_calls):
method(input, codes, codebooks, scales, parts, None)
end_event.record()
end_event.synchronize()
dur_ms = start_event.elapsed_time(end_event) / num_calls
return dur_ms
if __name__ == "__main__":
sys.exit(main())

View File

@ -86,6 +86,21 @@ void gelu_fast(
torch::Tensor& input);
#ifndef USE_ROCM
torch::Tensor aqlm_gemm(
const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const torch::Tensor& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias
);
torch::Tensor aqlm_dequant(
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& codebook_partition_sizes
);
torch::Tensor awq_gemm(
torch::Tensor _in_feats,
torch::Tensor _kernel,

View File

@ -63,6 +63,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Quantization ops
#ifndef USE_ROCM
ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");

View File

@ -0,0 +1,712 @@
/*
* Modified by Neural Magic
* Adapted from https://github.com/Vahe1994/AQLM
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>
#include <iostream>
#include <cstdlib>
namespace vllm {
namespace aqlm {
__global__ void Code1x16MatVec(
const int4* __restrict__ A,
const int4* __restrict__ B,
int4* __restrict__ C,
const int4* __restrict__ codebook,
const int prob_m,
const int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
const int codebook_stride // as int4.
) {
int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m;
if (pred)
{
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size)
{
codebook += codebook_stride;
++codebook_size;
}
}
int b_gl_rd = 0;
int c_gl_wr = a_gl_rd;
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
__shared__ int4 sh_b[32 * 9];
float res = 0;
int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
while (iters--) {
// We pad shared memory to avoid bank conflicts during reads
__syncthreads();
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
if (b_gl_rd + i < prob_k / 8)
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
}
__syncthreads();
b_gl_rd += 32 * 8;
int b_sh_rd = 9 * (threadIdx.x % 32);
if (pred && a_gl_rd < a_gl_end) {
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
#pragma unroll
for (int i = 0; i < 8; i++) {
uint32_t dec[4];
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
// actually help us; this brings > 2x speedup.
asm volatile (
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
: "l"((void*) &codebook[enc[i]])
);
half2* a = reinterpret_cast<half2*>(&dec);
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
half2 res2 = {};
#pragma unroll
for (int j = 0; j < 4; j++)
res2 = __hfma2(a[j], b[j], res2);
res += __half2float(res2.x) + __half2float(res2.y);
b_sh_rd++;
}
a_gl_rd += 32;
}
}
if (pred) {
#pragma unroll
for (int i = 16; i > 0; i /= 2)
res += __shfl_down_sync(0xffffffff, res, i);
if (threadIdx.x % 32 == 0)
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
}
}
__global__ void Code2x8MatVec(
const int4* __restrict__ A,
const int4* __restrict__ B,
int4* __restrict__ C,
const int4* __restrict__ codebook,
int prob_m,
int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
const int codebook_stride // as int4.
) {
int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m;
if (pred)
{
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size)
{
codebook += codebook_stride;
++codebook_size;
}
}
int b_gl_rd = 0;
int c_gl_wr = a_gl_rd;
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
int lane = threadIdx.x % 8;
extern __shared__ int4 sh[];
int4* sh_b = sh;
int4* sh_code = sh_b + 32 * 9;
int4* sh_code0 = sh_code;
int4* sh_code1 = sh_code + 256 * 8;
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
int4 dec = codebook[i];
#pragma unroll
for (int j = 0; j < 8; j++)
sh_code[8 * i + (j + lane) % 8] = dec;
}
__syncthreads();
float res = 0;
int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
while (iters--) {
// We pad shared memory to avoid bank conflicts during reads
__syncthreads();
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
if (b_gl_rd + i < prob_k / 8)
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
}
__syncthreads();
b_gl_rd += 32 * 8;
int b_sh_rd = 9 * (threadIdx.x % 32);
if (pred && a_gl_rd < a_gl_end) {
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
#pragma unroll
for (int i = 0; i < 8; i++) {
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
half2 res2 = {};
#pragma unroll
for (int j = 0; j < 4; j++)
res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2);
res += __half2float(res2.x) + __half2float(res2.y);
b_sh_rd++;
}
a_gl_rd += 32;
}
}
if (pred) {
#pragma unroll
for (int i = 16; i > 0; i /= 2)
res += __shfl_down_sync(0xffffffff, res, i);
if (threadIdx.x % 32 == 0)
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
}
}
__global__ void Code1x16Dequant(
const int4* __restrict__ A,
int4* __restrict__ C,
const int4* __restrict__ codebook,
int prob_m,
int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m.
const int codebook_stride // as int4
) {
int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m;
if (pred)
{
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size)
{
codebook += codebook_stride;
++codebook_size;
}
}
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
int c_gl_stride = prob_k / 8;
int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8;
int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
while (iters--) {
if (pred && a_gl_rd < a_gl_end) {
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
#pragma unroll
for (int i = 0; i < 8; i++) {
int4 chunk;
auto dec = reinterpret_cast<uint32_t*>(&chunk);
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
// actually help us; this brings > 2x speedup.
asm volatile (
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
: "l"((void*) &codebook[enc[i]])
);
C[a_gl_rd * 8 + i] = chunk;
}
}
a_gl_rd += 32;
}
}
__global__ void Code2x8Dequant(
const int4* __restrict__ A,
int4* __restrict__ C,
const int4* __restrict__ codebook,
int prob_m,
int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
const int codebook_stride // as int4
) {
int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m;
if (pred)
{
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size)
{
codebook += codebook_stride;
++codebook_size;
}
}
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
int lane = threadIdx.x % 8;
int c_gl_stride = prob_k / 8;
int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8;
extern __shared__ int4 sh[];
int4* sh_code = sh;
int4* sh_code0 = sh_code;
int4* sh_code1 = sh_code + 256 * 8;
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
int4 dec = codebook[i];
#pragma unroll
for (int j = 0; j < 8; j++)
sh_code[8 * i + (j + lane) % 8] = dec;
}
__syncthreads();
float res = 0;
int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
while (iters--) {
if (pred && a_gl_rd < a_gl_end) {
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
#pragma unroll
for (int i = 0; i < 8; i++) {
int4 chunk;
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
#pragma unroll
for (int j = 0; j < 4; j++)
reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
C[a_gl_rd * 8 + i] = chunk;
}
}
a_gl_rd += 32;
}
}
inline int ceildiv(int a, int b) {
return (a + b - 1) / b;
}
const int THREAD_M = 16;
void code1x16_matvec_cuda(
const void* __restrict__ A,
const void* __restrict__ B,
void* __restrict__ C,
const void* __restrict__ codebook,
int prob_m,
int prob_k,
const int4 codebook_a_sizes,
const int codebook_stride
) {
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
int waves = 0;
int thread_m;
do {
waves++;
thread_m = ceildiv(prob_m, waves * sms);
} while (thread_m > THREAD_M);
int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m;
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
Code1x16MatVec<<<blocks, threads, 16*32*9, stream>>>(
(const int4*) A,
(const int4*) B,
(int4*) C,
(const int4*) codebook,
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride
);
}
void code2x8_matvec_cuda(
const void* __restrict__ A,
const void* __restrict__ B,
void* __restrict__ C,
const void* __restrict__ codebook,
int prob_m,
int prob_k,
const int4 codebook_a_sizes,
const int codebook_stride
) {
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
int waves = 0;
int thread_m;
do {
waves++;
thread_m = ceildiv(prob_m, waves * sms);
} while (thread_m > THREAD_M);
int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m;
int shared = 16 * (2 * 256 * 8 + 32 * 9);
cudaFuncSetAttribute(
Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
Code2x8MatVec<<<blocks, threads, shared, stream>>>(
(const int4*) A,
(const int4*) B,
(int4*) C,
(const int4*) codebook,
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride
);
}
void code1x16_dequant_cuda(
const void* __restrict__ A,
void* __restrict__ C,
const void* __restrict__ codebook,
int prob_m,
int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
const int codebook_stride // as int4.
) {
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
int waves = 0;
int thread_m;
do {
waves++;
thread_m = ceildiv(prob_m, waves * sms);
} while (thread_m > THREAD_M);
int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m;
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
Code1x16Dequant<<<blocks, threads, 0, stream>>>(
(const int4*) A,
(int4*) C,
(const int4*) codebook,
prob_m,
prob_k,
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
codebook_stride // as int4.
);
}
// Dequantizes the code and codebook into weights.
void code2x8_dequant_cuda(
const void* __restrict__ A,
void* __restrict__ C,
const void* __restrict__ codebook,
int prob_m,
int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
const int codebook_stride // as int4
) {
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
int waves = 0;
int thread_m;
do {
waves++;
thread_m = ceildiv(prob_m, waves * sms);
} while (thread_m > THREAD_M);
int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m;
int shared = 16 * (2 * 256 * 8 + 32 * 9);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cudaFuncSetAttribute(
Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
);
Code2x8Dequant<<<blocks, threads, shared, stream>>>(
(const int4*) A,
(int4*) C,
(const int4*) codebook,
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride
);
}
int codebook_stride(const torch::Tensor& codebooks)
{
return codebooks.stride(0) * codebooks.element_size() / sizeof(int4);
}
void code1x16_matvec(
const torch::Tensor& A,
const torch::Tensor& B,
torch::Tensor& C,
const torch::Tensor& codebook,
const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long.
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
int prob_m = C.size(0);
int prob_k = B.size(0);
code1x16_matvec_cuda(
A.data_ptr(),
B.data_ptr(),
C.data_ptr(),
codebook.data_ptr(),
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride(codebook)
);
}
torch::Tensor code1x16_matmat(
const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const int4 codebook_a_sizes,
const std::optional<torch::Tensor>& bias) {
auto input_sizes = input.sizes();
auto out_features = codes.size(0) * codebooks.size(2);
auto flat_input = input.reshape({-1, input.size(-1)});
auto flat_output = torch::empty({flat_input.size(0), out_features},
torch::TensorOptions()
.dtype(input.dtype())
.device(input.device())
);
for (int i = 0; i < flat_input.size(0); ++i) {
auto input_vec = flat_input.index({i});
auto output_vec = flat_output.index({i});
code1x16_matvec(
codes.squeeze(2),
input_vec,
output_vec,
codebooks,
codebook_a_sizes
);
}
flat_output *= scales.flatten().unsqueeze(0);
if (bias.has_value()) {
flat_output += bias->unsqueeze(0);
}
auto output_sizes = input_sizes.vec();
output_sizes.pop_back();
output_sizes.push_back(-1);
auto output = flat_output.reshape(output_sizes);
return output;
}
void code2x8_matvec(
const torch::Tensor& A,
const torch::Tensor& B,
torch::Tensor& C,
const torch::Tensor& codebook,
const int4 codebook_a_sizes
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
int prob_m = C.size(0);
int prob_k = B.size(0);
code2x8_matvec_cuda(
A.data_ptr(),
B.data_ptr(),
C.data_ptr(),
codebook.data_ptr(),
prob_m,
prob_k,
codebook_a_sizes,
2 * codebook_stride(codebook)
);
}
torch::Tensor code2x8_matmat(
const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const int4 codebook_a_sizes,
const std::optional<torch::Tensor>& bias
) {
auto input_sizes = input.sizes();
auto out_features = codes.size(0) * codebooks.size(2);
auto flat_input = input.reshape({-1, input.size(-1)});
auto flat_output = torch::empty({flat_input.size(0), out_features},
torch::TensorOptions()
.dtype(input.dtype())
.device(input.device())
);
for (int i = 0; i < flat_input.size(0); ++i) {
auto input_vec = flat_input.index({i});
auto output_vec = flat_output.index({i});
code2x8_matvec(
codes.squeeze(2),
input_vec,
output_vec,
codebooks,
codebook_a_sizes
);
}
flat_output *= scales.flatten().unsqueeze(0);
if (bias.has_value()) {
flat_output += bias->unsqueeze(0);
}
auto output_sizes = input_sizes.vec();
output_sizes.pop_back();
output_sizes.push_back(-1);
auto output = flat_output.reshape(output_sizes);
return output;
}
// Accumulate the partition sizes.
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes)
{
int4 cumulative_sizes;
auto cumulative_size = &cumulative_sizes.x;
int i = 0;
int last = 0;
assert(codebook_partition_sizes.size(0) <= 4);
for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size)
{
*cumulative_size = codebook_partition_sizes[i].item<int>() + last;
last = *cumulative_size;
}
// fill in the rest with unreachable.
for (; i < 4; ++i, ++cumulative_size)
{
*cumulative_size = last*10;
}
return cumulative_sizes;
}
} // namespace aqlm
} // namespace vllm
torch::Tensor aqlm_gemm(
const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const torch::Tensor& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias
)
{
int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
int const entries = codebooks.size(1);
if (nbooks == 1 && entries == (1 << 16))
{
return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias);
}
if (nbooks == 2 && entries == (1 << 8))
{
return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias);
}
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.")
return {};
}
torch::Tensor aqlm_dequant(
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& codebook_partition_sizes
)
{
int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
int const entries = codebooks.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(codes));
int rows = codes.size(1);
int cols = codes.size(0);
auto in_features = codes.size(1) * 8;
auto out_features = codes.size(0);
assert(out_features = codebook_partition_sizes.sum().item<int>());
auto weights = torch::empty({out_features, in_features},
torch::TensorOptions()
.dtype(codebooks.dtype())
.device(codebooks.device())
);
if (nbooks == 1 && entries == (1 << 16))
{
vllm::aqlm::code1x16_dequant_cuda(
codes.data_ptr(),
weights.data_ptr(),
codebooks.data_ptr(),
out_features,
in_features,
cumulative_sizes,
vllm::aqlm::codebook_stride(codebooks));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.)
// weights *= scales.index({"...", 0, 0});
return weights;
}
if (nbooks == 2 && entries == (1 << 8))
{
vllm::aqlm::code2x8_dequant_cuda(
codes.data_ptr(),
weights.data_ptr(),
codebooks.data_ptr(),
out_features,
in_features,
cumulative_sizes,
vllm::aqlm::codebook_stride(codebooks));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation)
// weights *= scales.index({"...", 0, 0});
return weights;
}
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.")
return {};
}

46
examples/aqlm_example.py Normal file
View File

@ -0,0 +1,46 @@
import argparse
from vllm import LLM, SamplingParams
def main():
parser = argparse.ArgumentParser(description='AQLM examples')
parser.add_argument('--model',
'-m',
type=str,
default=None,
help='model path, as for HF')
parser.add_argument('--choice',
'-c',
type=int,
default=0,
help='known good models by index, [0-4]')
parser.add_argument('--tensor_parallel_size',
'-t',
type=int,
default=1,
help='tensor parallel size')
args = parser.parse_args()
models = [
"ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf",
"ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf",
"ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf",
"ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf",
"BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf",
]
model = LLM(args.model if args.model is not None else models[args.choice],
tensor_parallel_size=args.tensor_parallel_size)
sampling_params = SamplingParams(max_tokens=100, temperature=0)
outputs = model.generate("Hello my name is",
sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
if __name__ == '__main__':
main()

95
tests/models/test_aqlm.py Normal file
View File

@ -0,0 +1,95 @@
"""Compare the outputs of a AQLM model between vLLM and HF Transformers
Run `pytest tests/models/test_aqlm.py`.
"""
import pytest
import torch
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
aqlm_not_supported = (capability <
QUANTIZATION_METHODS["aqlm"].get_min_capability())
# In this test we hardcode prompts and generations for the model so we don't
# need to require the AQLM package as a dependency
example_prompts = [
'vLLM is a high-throughput and memory-efficient inference and serving '
'engine for LLMs.\n',
'Briefly describe the major milestones in the development of artificial '
'intelligence from 1950 to 2020.\n',
'Compare and contrast artificial intelligence with human intelligence in '
'terms of processing information.\n',
'Describe the basic components of a neural network and how it can be '
'trained.\n',
'Write a short story about a robot that dreams for the first time.\n',
'Analyze the impact of the COVID-19 pandemic on global economic structures '
'and future business models.\n',
'Explain the cultural significance of the Mona Lisa painting, and how its '
'perception might vary in Western versus Eastern societies.\n',
"Translate the following English sentence into Japanese, French, and "
"Swahili: 'The early bird catches the worm.'\n"
]
# These ground truth generations were generated using `transformers==4.38.1
# aqlm==1.1.0 torch==2.2.0`
# and the below code:
# ```python
# from transformers import AutoTokenizer, AutoModelForCausalLM
# model_id = "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"
# quantized_model = AutoModelForCausalLM.from_pretrained(model_id,
# torch_dtype="auto", device_map="cuda").cuda()
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# outputs = []
# for prompt in example_prompts:
# input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")
# hf_outputs = quantized_model.generate(input_ids, max_new_tokens=32)
# outputs.append(tokenizer.decode(hf_outputs[0][input_ids.shape[1]:]))
# print(outputs)
# ```
ground_truth_generations = [
'\n### Features\n\n- **High-throughput**: v',
'The major milestones in the development of artificial intelligence from '
'195',
'Compare and contrast artificial intelligence with human intelligence in '
'terms of processing information. The',
'Explain the difference between supervised and unsupervised learning.'
'\nExplain',
'Write a short story about a robot that dreams for the first time. The',
'Analyze the impact of the COVID-19 pandemic on global economic',
'The Mona Lisa is a painting by Leonardo da Vinci, and it',
'The early bird catches the worm.\nThe early bird catches the'
]
@pytest.mark.skipif(aqlm_not_supported,
reason="AQLM is not supported on this GPU type.")
@pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [16])
@pytest.mark.parametrize("num_logprobs", [1])
def test_models(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
vllm_model = vllm_runner(model, dtype=dtype)
vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts,
max_tokens,
num_logprobs)
# loop through the prompts to compare against the ground truth generations
for prompt_idx in range(len(example_prompts)):
vllm_output_ids, vllm_output_str, vllm_logprobs = vllm_outputs[
prompt_idx]
print("Prompt: ", repr(example_prompts[prompt_idx]))
print("Reference output:", repr(ground_truth_generations[prompt_idx]))
print("Output output: ", repr(vllm_output_str))
assert vllm_output_str == ground_truth_generations[prompt_idx]

View File

@ -31,7 +31,7 @@ class LinearMethodBase(ABC):
@abstractmethod
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""Create weights for a linear layer.
@ -70,9 +70,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
output_size_per_partition = sum(output_partition_sizes)
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=params_dtype),
@ -127,7 +128,7 @@ class ReplicatedLinear(torch.nn.Module):
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
self.linear_method.create_weights(self, self.input_size,
self.output_size, self.input_size,
[self.output_size], self.input_size,
self.output_size, self.params_dtype)
if bias:
self.bias = Parameter(
@ -161,6 +162,8 @@ class ColumnParallelLinear(torch.nn.Module):
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method.
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
"""
def __init__(
@ -172,6 +175,7 @@ class ColumnParallelLinear(torch.nn.Module):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None,
output_sizes: Optional[List[int]] = None,
):
super().__init__()
@ -188,10 +192,12 @@ class ColumnParallelLinear(torch.nn.Module):
self.params_dtype = params_dtype
if linear_method is None:
linear_method = UnquantizedLinearMethod()
if output_sizes is None:
output_sizes = [output_size]
self.linear_method = linear_method
self.linear_method.create_weights(self,
self.input_size,
self.output_size_per_partition,
[x // tp_size for x in output_sizes],
self.input_size,
self.output_size,
self.params_dtype,
@ -268,14 +274,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size, sum(output_sizes), bias, gather_output,
skip_bias_add, params_dtype, linear_method)
skip_bias_add, params_dtype, linear_method,
self.output_sizes)
def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
param_data = param.data
output_dim = getattr(param, "output_dim", None)
is_metadata = getattr(param, "is_metadata", False)
if loaded_shard_id is None:
# Loaded weight is already packed.
if output_dim is None:
@ -328,6 +337,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
@ -393,8 +407,14 @@ class QKVParallelLinear(ColumnParallelLinear):
input_size = self.hidden_size
output_size = (self.num_heads +
2 * self.num_kv_heads) * tp_size * self.head_size
output_sizes = [
self.num_heads * tp_size * self.head_size,
self.num_kv_heads * tp_size * self.head_size,
self.num_kv_heads * tp_size * self.head_size
]
super().__init__(input_size, output_size, bias, False, skip_bias_add,
params_dtype, linear_method)
params_dtype, linear_method, output_sizes)
def weight_loader(self,
param: Parameter,
@ -402,6 +422,7 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_shard_id: Optional[str] = None):
param_data = param.data
output_dim = getattr(param, "output_dim", None)
is_metadata = getattr(param, "is_metadata", False)
if loaded_shard_id is None:
# Loaded weight is already packed.
@ -469,6 +490,12 @@ class QKVParallelLinear(ColumnParallelLinear):
start_idx = shard_id * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0]
shard_index = ["q", "k", "v"].index(loaded_shard_id)
param_data = param_data.narrow(0, shard_index * shard_size,
shard_size)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
@ -536,7 +563,7 @@ class RowParallelLinear(torch.nn.Module):
self.linear_method = linear_method
self.linear_method.create_weights(self,
self.input_size_per_partition,
self.output_size,
[self.output_size],
self.input_size,
self.output_size,
self.params_dtype,

View File

@ -1,5 +1,6 @@
from typing import Type
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
@ -9,6 +10,7 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
QUANTIZATION_METHODS = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"fp8": FP8Config,
"gptq": GPTQConfig,

View File

@ -0,0 +1,373 @@
# Supports AQLM compression, see https://github.com/Vahe1994/AQLM
# and https://arxiv.org/pdf/2401.06118.pdf
import math
from typing import Any, Dict, List, Optional
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
def get_int_dtype(nbits: int) -> torch.dtype:
if nbits <= 8:
return torch.int8
if nbits <= 16:
return torch.int16
if nbits <= 32:
return torch.int32
if nbits <= 64:
return torch.int64
raise ValueError(f"No dtype available for {nbits}-bit codebooks")
@torch.inference_mode()
def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
return data.to(torch.int64) % (2**nbits)
def dequantize_weight(codes: torch.Tensor,
codebooks: torch.Tensor,
scales: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Decode float weights from quantization codes. Differentiable.
:param codes: tensor of integer quantization codes, shape
[*dims, num_out_groups, num_in_groups, num_codebooks]
:param codebooks: tensor of vectors for each quantization code,
[num_codebooks, codebook_size, out_group_size, in_group_size]
:param scales: weight will be multiplied by this factor, must be
broadcastble with
[*dims, out_groups, num_in_groups, out_group_size, in_group_size]
:return: reconstructed weight tensor of shape
[*dims, num_in_groups*group_size]
"""
num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
num_codebooks, codebook_size, out_group_size, in_group_size = \
codebooks.shape
out_features = num_out_groups * out_group_size
in_features = num_in_groups * in_group_size
codebook_offsets = torch.arange(
0, num_codebooks * codebook_size, codebook_size,
device=codes.device) # shape: [num_codebooks]
reconstructed_weight_flat = F.embedding_bag(
codes.flatten(0, -2) + codebook_offsets,
codebooks.flatten(0, 1).flatten(-2, -1),
mode="sum"
) # [prod(dims) * num_out_groups * num_in_groups, out_group_size
# * in_group_size]
reconstructed_weight_groupwise = reconstructed_weight_flat.view(
list(codes.shape[:-3]) +
[num_out_groups, num_in_groups, out_group_size, in_group_size])
if scales is not None:
reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(
scales)
return reconstructed_weight_groupwise.swapaxes(
-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])
def dequantize_gemm(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
bias: Optional[torch.Tensor],
) -> torch.Tensor:
dequantized_weight = dequantize_weight(
unpack_int_data(codes, codebooks.shape[1].bit_length() - 1),
codebooks,
scales,
)
return F.linear(input, dequantized_weight, bias)
# Generic dequantization, slow but flexible.
def generic_dequantize_gemm(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
output_shape = input.shape[:-1] + (scales.shape[0], )
output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
num_outputs = len(output_partition_sizes)
# break the inputs and codebooks apart then combine the outputs.
# Surprisingly (to me) this is faster than doing 3 de-quants and 1 big
# multiply at the end.
num_codebooks = codebooks.shape[0] // num_outputs
assert (scales.shape[0] == codes.shape[0])
assert (sum(output_partition_sizes) == scales.shape[0])
output_offset = 0
codebooks_offset = 0
for output_size in output_partition_sizes:
shard_output = dequantize_gemm(
input, codes.narrow(0, output_offset, output_size),
codebooks.narrow(0, codebooks_offset, num_codebooks),
scales.narrow(0, output_offset, output_size), None
if bias is None else bias.narrow(0, output_offset, output_size))
output_slice = output.narrow(-1, output_offset, output_size)
assert (output_slice.shape == shard_output.shape)
output_slice.copy_(shard_output)
output_offset += output_size
codebooks_offset += num_codebooks
return output
# Optimized dequnantize/decompression kernels, supports 1x16 and 2x8
# at 6 and 9 times faster than the generic version above, respectively.
def optimized_dequantize_gemm(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
if bias is None:
# scaling the output is fastest, so we do that when possible.
output = F.linear(input, weights, bias)
orig_shape = output.shape
flattened_output = output.view(-1, output.size(-1))
f_scales = scales.view(-1, scales.shape[0])
b_scales = f_scales.expand(flattened_output.shape[0], -1)
flattened_output *= b_scales
return output.view(orig_shape)
else:
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
-1, weights.shape[1])
weights *= b_scales
return F.linear(input, weights, bias)
class AQLMConfig(QuantizationConfig):
"""Config class for AQLM.
Reference: https://github.com/Vahe1994/AQLM
"""
def __init__(
self,
in_group_size: int,
nbits_per_codebook: int,
num_codebooks: int,
out_group_size: int,
) -> None:
self.in_group_size = in_group_size
self.nbits_per_codebook = nbits_per_codebook
self.num_codebooks = num_codebooks
self.out_group_size = out_group_size
# out_group_size > 1 is untested, and probably won't work as-is.
assert (self.out_group_size == 1)
self.pack_factor = (self.in_group_size * self.out_group_size)
def __repr__(self) -> str:
return (f"AQLMConfig(in_group_size={self.in_group_size}, "
f"nbits_per_codebook={self.nbits_per_codebook}, "
f"num_codebooks={self.num_codebooks}, "
f"out_group_size={self.out_group_size})")
@classmethod
def get_name(cls) -> str:
return "aqlm"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 70
@classmethod
def get_config_filenames(cls) -> List[str]:
return [] # no extra configs.
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig":
in_group_size = cls.get_from_keys(config, ["in_group_size"])
nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"])
num_code_books = cls.get_from_keys(config, ["num_codebooks"])
out_group_size = cls.get_from_keys(config, ["out_group_size"])
return cls(in_group_size, nbits_per_codebook, num_code_books,
out_group_size)
def get_linear_method(self) -> "AQLMLinearMethod":
return AQLMLinearMethod(self)
def get_scaled_act_names(self) -> List[str]:
return []
class AQLMLinearMethod(LinearMethodBase):
"""Linear method for AQLM.
Args:
quant_config: The AQLM quantization config.
"""
def __init__(self, quant_config: AQLMConfig):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
del output_size # Unused.
del input_size # Unused.
if params_dtype != torch.half:
raise ValueError("Only half is currently supported by aqlm")
if input_size_per_partition % self.quant_config.in_group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.out_group_size != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
codes = Parameter(
torch.empty(
# There could actually be two pack factors, one along input and
# one along output, but we don't currently support
# out_group_size, and only the one along output needs to be
# marked with "packed_dim" in order for QKVLinear to work.
output_size_per_partition,
input_size_per_partition // self.quant_config.pack_factor,
self.quant_config.num_codebooks,
dtype=get_int_dtype(self.quant_config.nbits_per_codebook),
),
requires_grad=False,
)
set_weight_attrs(
codes,
{
"input_dim": 1,
"output_dim": 0,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
},
)
codebooks = Parameter(
torch.empty(
self.quant_config.num_codebooks * len(output_partition_sizes),
2**self.quant_config.nbits_per_codebook,
self.quant_config.out_group_size,
self.quant_config.in_group_size,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
codebooks,
{
# metadata indicates fixed size concatenated along dim 0
"is_metadata":
True,
"output_partition_sizes":
torch.tensor(output_partition_sizes, device='cpu'),
},
)
scales = Parameter(
torch.empty(
(
output_size_per_partition //
self.quant_config.out_group_size,
1,
1,
1,
),
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
"output_dim": 0,
"packed_dim": 0,
"pack_factor": self.quant_config.out_group_size
},
)
layer.register_parameter("codes", codes)
set_weight_attrs(codes, extra_weight_attrs)
layer.register_parameter("codebooks", codebooks)
set_weight_attrs(codebooks, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
codebooks = layer.codebooks
codes = layer.codes
scales = layer.scales
output_partition_sizes = getattr(codebooks, "output_partition_sizes",
None)
nbooks = codes.shape[2]
ingroups = codebooks.shape[3]
outgroups = codebooks.shape[2]
bits = codebooks.shape[1]
# We support these formats with dedicated gemm and decompression
# kernels.
if ingroups == 8 and outgroups == 1 and (
(bits == 256 and nbooks == 2) or (bits == 65536 and nbooks == 1)):
# thresholds determined by timings on an A6000, one GPU
use_gemv = math.prod(x.shape[:-1]) <= 6
return ops.aqlm_gemm(
x,
codes,
codebooks,
scales,
output_partition_sizes,
bias,
) if use_gemv else optimized_dequantize_gemm(
x,
codes,
codebooks,
scales,
output_partition_sizes,
bias,
)
# fall back all unoptimized formats
return generic_dequantize_gemm(
x,
codes,
codebooks,
scales,
output_partition_sizes,
bias,
)

View File

@ -81,7 +81,7 @@ class AWQLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if input_size_per_partition % self.quant_config.group_size != 0:
@ -89,6 +89,8 @@ class AWQLinearMethod(LinearMethodBase):
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The output size is not aligned with the quantized "

View File

@ -91,7 +91,7 @@ class GPTQLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
@ -103,6 +103,7 @@ class GPTQLinearMethod(LinearMethodBase):
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor.numerator
!= 0):
raise ValueError(

View File

@ -93,7 +93,7 @@ class MarlinLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
@ -106,6 +106,7 @@ class MarlinLinearMethod(LinearMethodBase):
f"The params dtype must be float16, but got {params_dtype}")
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.min_n_threads != 0:
raise ValueError(
f"Weight output_size_per_partition = "

View File

@ -70,7 +70,7 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if input_size_per_partition % self.quant_config.pack_factor != 0:
@ -78,6 +78,8 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.pack_factor,