Support W4A8 quantization for vllm (#5218)
parent c0644cf9ce
commit 6512937de1

.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml (new file, 11 lines)
@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1
model_name: "HandH1998/QQQ-Llama-3-8b-g128"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.409
  - name: "exact_match,flexible-extract"
    value: 0.406
limit: 1000
num_fewshot: 5
@@ -7,3 +7,4 @@ Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
 Minitron-4B-Base.yaml
 Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-FP8W8.yaml
+Meta-Llama-3-8B-QQQ.yaml
@@ -170,6 +170,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/awq/gemm_kernels.cu"
     "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
     "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
+    "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
     "csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
@@ -115,6 +115,13 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b_scales,
                        c10::optional<torch::Tensor> const& bias);
 
+torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
+                              torch::Tensor const& b_q_weight,
+                              torch::Tensor const& s_tok,
+                              torch::Tensor const& s_ch,
+                              torch::Tensor const& s_group,
+                              torch::Tensor& workspace, int64_t size_m,
+                              int64_t size_n, int64_t size_k);
 #endif
 
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
csrc/quantization/marlin/dense/common/base.h (new file, 32 lines)
@@ -0,0 +1,32 @@
/*
 * Modified by HandH1998
 * Modified by Neural Magic
 * Copyright (C) Marlin.2024 Elias Frantar
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) { return elems[i]; }
};
csrc/quantization/marlin/dense/common/mem.h (new file, 89 lines)
@@ -0,0 +1,89 @@
/*
 * Modified by HandH1998
 * Modified by Neural Magic
 * Copyright (C) Marlin.2024 Elias Frantar
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}

// Async copy fence.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
    int state = -1;
    do
      // Guarantee that subsequent writes by this threadblock will be visible
      // globally.
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(state)
                   : "l"(lock));
    while (state != count);
  }
  __syncthreads();
}

// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      lock[0] = 0;
      return;
    }
    int val = 1;
    // Make sure that all writes since acquiring this barrier are visible
    // globally, while releasing the barrier.
    asm volatile("fence.acq_rel.gpu;\n");
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                 :
                 : "l"(lock), "r"(val));
  }
}
@@ -25,6 +25,12 @@
 
 #include <iostream>
 
+#include "common/base.h"
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#include "common/mem.h"
+#endif
+
 template <typename T>
 inline std::string str(T x) {
   return std::to_string(x);
@@ -32,23 +38,9 @@ inline std::string str(T x) {
 
 namespace marlin_dense {
 
-constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
-
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 
-// Instances of `Vec` are used to organize groups of >>registers<<, as needed
-// for instance as inputs to tensor core operations. Consequently, all
-// corresponding index accesses must be compile-time constants, which is why we
-// extensively use `#pragma unroll` throughout the kernel code to guarantee
-// this.
-template <typename T, int n>
-struct Vec {
-  T elems[n];
-  __device__ T& operator[](int i) { return elems[i]; }
-};
-
 using I4 = Vec<int, 4>;
 
 // Matrix fragments for tensor core instructions; their precise layout is
 // documented here:
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
@@ -57,43 +49,6 @@ using FragB = Vec<half2, 2>;
 using FragC = Vec<float, 4>;
 using FragS = Vec<half2, 1>;  // quantization scales
 
-// Predicated asynchronous global->shared copy; used for inputs A where we apply
-// predication to handle batchsizes that are not multiples of 16.
-__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
-                                      bool pred = true) {
-  const int BYTES = 16;
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-  asm volatile(
-      "{\n"
-      "   .reg .pred p;\n"
-      "   setp.ne.b32 p, %0, 0;\n"
-      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
-      "}\n" ::"r"((int)pred),
-      "r"(smem), "l"(glob_ptr), "n"(BYTES));
-}
-
-// Asynchronous global->shared copy
-__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
-  const int BYTES = 16;
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-  asm volatile(
-      "{\n"
-      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
-      "}\n" ::"r"(smem),
-      "l"(glob_ptr), "n"(BYTES));
-}
-
-// Async copy fence.
-__device__ inline void cp_async_fence() {
-  asm volatile("cp.async.commit_group;\n" ::);
-}
-
-// Wait until at most `n` async copy stages are still pending.
-template <int n>
-__device__ inline void cp_async_wait() {
-  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
-}
-
 // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
 // output/accumulation.
 __device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
@@ -164,39 +119,6 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
   frag_b[1] = __hmul2(frag_b[1], s);
 }
 
-// Wait until barrier reaches `count`, then lock for current threadblock.
-__device__ inline void barrier_acquire(int* lock, int count) {
-  if (threadIdx.x == 0) {
-    int state = -1;
-    do
-      // Guarantee that subsequent writes by this threadblock will be visible
-      // globally.
-      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
-                   : "=r"(state)
-                   : "l"(lock));
-    while (state != count);
-  }
-  __syncthreads();
-}
-
-// Release barrier and increment visitation count.
-__device__ inline void barrier_release(int* lock, bool reset = false) {
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    if (reset) {
-      lock[0] = 0;
-      return;
-    }
-    int val = 1;
-    // Make sure that all writes since acquiring this barrier are visible
-    // globally, while releasing the barrier.
-    asm volatile("fence.acq_rel.gpu;\n");
-    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
-                 :
-                 : "l"(lock), "r"(val));
-  }
-}
-
 template <const int threads,          // number of threads in a threadblock
           const int thread_m_blocks,  // number of 16x16 blocks in the m
                                       // dimension (batchsize) of the
csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu (new file, 1243 lines)
File diff suppressed because it is too large.
@@ -149,6 +149,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
   ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
 
+  // marlin_qqq_gemm for QQQ.
+  ops.def("marlin_qqq_gemm", &marlin_qqq_gemm);
+  ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
+
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
   // quantization.
   ops.def(
@@ -10,6 +10,9 @@ from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
     GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
+from vllm.model_executor.layers.quantization.qqq import (
+    MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N,
+    MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
     MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS,
@@ -21,6 +24,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     marlin_weights)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
     marlin_24_quantize)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import ( # noqa: E501
+    marlin_qqq_quantize)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     awq_pack, gptq_pack, quantize_weights, quantize_weights_with_zp,
     sort_weights)
@@ -425,3 +430,64 @@ def test_awq_marlin_gemm(
     print("max_diff = {}".format(max_diff))
 
     assert max_diff < 0.04
+
+
+@pytest.mark.skipif(not is_quant_method_supported("qqq"),
+                    reason="Marlin is not supported on this GPU type.")
+@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
+@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
+@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS)
+@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES)
+@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
+def test_marlin_qqq_gemm(
+    k_chunk,
+    n_chunk,
+    num_bits,
+    group_size,
+    mnk_factors,
+):
+    int8_traits = torch.iinfo(torch.int8)
+    m_factor, n_factor, k_factor = mnk_factors
+
+    size_m = m_factor
+    size_k = k_chunk * k_factor
+    size_n = n_chunk * n_factor
+
+    print(f"MNK = {size_m} {size_n} {size_k}")
+    print(f"groupsize = {group_size}")
+
+    a_input = rand_data((size_m, size_k))
+    b_weight = rand_data((size_k, size_n))
+
+    # Quantize activations
+    s_a = a_input.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to(
+        torch.float)
+    q_a = (a_input / s_a).round().clamp(int8_traits.min,
+                                        int8_traits.max).to(torch.int8)
+
+    # Quantize weights
+    w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = \
+        marlin_qqq_quantize(b_weight, num_bits, group_size)
+
+    workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
+                                MARLIN_QQQ_MAX_PARALLEL)
+
+    output = ops.marlin_qqq_gemm(
+        q_a,
+        marlin_qqq_q_w,
+        s_a,
+        marlin_qqq_s_channel,
+        marlin_qqq_s_group,
+        workspace.scratch,
+        a_input.shape[0],
+        b_weight.shape[1],
+        a_input.shape[1],
+    )
+    output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref)
+
+    torch.cuda.synchronize()
+
+    max_diff = compute_max_diff(output, output_ref)
+    print("max_diff = {}".format(max_diff))
+
+    assert max_diff < 0.04
@@ -389,6 +389,15 @@ def scaled_int8_quant(
     return output, input_scales
 
 
+# qqq ops
+def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+                    s_tok: torch.Tensor, s_ch: torch.Tensor,
+                    s_group: torch.Tensor, workspace: torch.Tensor,
+                    size_m: int, size_n: int, size_k: int) -> torch.Tensor:
+    return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group,
+                                        workspace, size_m, size_n, size_k)
+
+
 # moe
 def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                          block_size: int, sorted_token_ids: torch.Tensor,
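For orientation, here is a minimal sketch (not part of the diff) of how this wrapper composes with the other pieces added in this change. The problem sizes are illustrative, and a CUDA GPU with compute capability >= 8.0 is assumed; the workspace is sized the same way QQQLinearMethod.create_weights does further below.

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.qqq import (
    MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import (
    marlin_qqq_quantize)

# Illustrative problem size only.
size_m, size_k, size_n = 16, 4096, 4096
num_bits, group_size = 4, 128

a = torch.randn((size_m, size_k), dtype=torch.half, device="cuda")
w = torch.randn((size_k, size_n), dtype=torch.half, device="cuda")

# W4A8: per-token int8 activations, int4 weights packed into the marlin_qqq layout.
a_int8, s_tok = ops.scaled_int8_quant(a)
w_ref, q_w, s_group, s_channel = marlin_qqq_quantize(w, num_bits, group_size)

# Workspace used by the kernel's internal locking mechanism.
workspace = torch.zeros((size_n // MARLIN_QQQ_MIN_THREAD_N) *
                        MARLIN_QQQ_MAX_PARALLEL,
                        device="cuda",
                        dtype=torch.int)

out = ops.marlin_qqq_gemm(a_int8, q_w, s_tok, s_channel, s_group, workspace,
                          size_m, size_n, size_k)
# out approximately equals (a_int8.half() * s_tok.half()) @ w_ref,
# which is exactly the reference used in test_marlin_qqq_gemm above.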
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQMarlin24Config)
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+from vllm.model_executor.layers.quantization.qqq import QQQConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
 
 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
@@ -37,6 +38,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "squeezellm": SqueezeLLMConfig,
     "compressed-tensors": CompressedTensorsConfig,
     "bitsandbytes": BitsAndBytesConfig,
+    "qqq": QQQConfig,
 }
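With the method registered under the key "qqq", a QQQ checkpoint such as the one exercised by the new lm-eval-harness config can be loaded like any other quantized model. A hedged usage sketch (the engine-level arguments here are an assumption of this sketch, not something shown in the diff):

from vllm import LLM, SamplingParams

# "qqq" is the key added to QUANTIZATION_METHODS above; the checkpoint is the
# one referenced by .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml.
llm = LLM(model="HandH1998/QQQ-Llama-3-8b-g128", quantization="qqq")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)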
vllm/model_executor/layers/quantization/qqq.py (new file, 285 lines)
@@ -0,0 +1,285 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)

MARLIN_QQQ_TILE = 16
MARLIN_QQQ_MIN_THREAD_N = 64
MARLIN_QQQ_MIN_THREAD_K = 128
MARLIN_QQQ_MAX_PARALLEL = 16

MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
MARLIN_QQQ_SUPPORTED_SYM = [True]


class QQQConfig(QuantizationConfig):
    """Config class for QQQ

    Reference: https://arxiv.org/pdf/2406.09904
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        is_sym: bool = True,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.is_sym = is_sym

        # Verify
        if self.weight_bits not in MARLIN_QQQ_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"QQQ does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {MARLIN_QQQ_SUPPORTED_NUM_BITS} "
                "are supported.")
        if self.group_size not in MARLIN_QQQ_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"QQQ does not support group_size = {self.group_size}. "
                f"Only group_sizes = {MARLIN_QQQ_SUPPORTED_GROUP_SIZES} "
                "are supported.")
        if self.is_sym not in MARLIN_QQQ_SUPPORTED_SYM:
            raise ValueError(
                f"QQQ does not support is_sym = {self.is_sym}. "
                f"Only sym = {MARLIN_QQQ_SUPPORTED_SYM} are supported.")

        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = 32 // self.weight_bits

        # Tile size used by QQQ kernels.
        self.tile_size = MARLIN_QQQ_TILE

        # Min out_features dim
        self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N

        # Min in_features dim
        self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K

        # Max parallel problems to solve at once (improves large
        # batch performance)
        self.max_parallel = MARLIN_QQQ_MAX_PARALLEL

        # Permutation length used by the QQQ kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        return "QQQConfig(weight_bits={}, group_size={})".format(
            self.weight_bits, self.group_size)

    @classmethod
    def get_name(cls) -> str:
        return "qqq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """List of filenames to search for in the model directory."""
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QQQConfig":
        weight_bits = cls.get_from_keys(config, ["wbits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QQQLinearMethod"]:
        if isinstance(layer, LinearBase):
            return QQQLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class QQQLinearMethod(LinearMethodBase):
    """Linear method for QQQ.

    Args:
        quant_config: The QQQ quantization config.
    """

    def __init__(self, quant_config: QQQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_n_threads = {self.quant_config.min_n_threads}.")
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"pack_factor = {self.quant_config.pack_factor}.")

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"min_k_threads = {self.quant_config.min_k_threads}.")
        if (self.quant_config.group_size != -1 and
                input_size_per_partition % self.quant_config.group_size != 0):
            raise ValueError(f"Weight input_size_per_partition = "
                             f"{input_size_per_partition} is not divisible by "
                             f"group_size = {self.quant_config.group_size}.")

        # Check that we have at least 4 tiles horizontally in the shard
        num_tiles_per_perm = self.quant_config.perm_len // (
            self.quant_config.tile_size**2)
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

        # Quantized 4Bit weights packed into Int32.
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.tile_size,
                output_size_per_partition * self.quant_config.tile_size //
                self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
                "marlin_tile_size": self.quant_config.tile_size,
            },
        )

        s_channel = Parameter(
            torch.empty(
                1,
                output_size_per_partition,
                device="cuda",
                dtype=torch.float,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            s_channel,
            {
                "input_dim": None,
                "output_dim": 1,
            },
        )

        if self.quant_config.group_size == -1:
            s_group = Parameter(
                torch.tensor(
                    [],
                    device="cuda",
                    dtype=torch.half,
                ),
                requires_grad=False,
            )
        else:
            s_group = Parameter(
                torch.empty(
                    input_size_per_partition // self.quant_config.group_size,
                    output_size_per_partition,
                    device="cuda",
                    dtype=torch.half,
                ),
                requires_grad=False,
            )

        set_weight_attrs(
            s_group,
            {
                "input_dim": None if self.quant_config.group_size == -1 else 0,
                "output_dim":
                None if self.quant_config.group_size == -1 else 1,
            },
        )

        # Allocate workspace (Used for internal locking mechanism)
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_n_threads) * self.quant_config.max_parallel
        workspace = Parameter(torch.zeros(max_workspace_size,
                                          device="cuda",
                                          dtype=torch.int),
                              requires_grad=False)

        layer.register_parameter("B", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("s_channel", s_channel)
        set_weight_attrs(s_channel, extra_weight_attrs)
        layer.register_parameter("s_group", s_group)
        set_weight_attrs(s_group, extra_weight_attrs)
        layer.register_parameter("workspace", workspace)
        set_weight_attrs(workspace, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qweight = layer.B
        s_ch = layer.s_channel
        s_group = layer.s_group
        workspace = layer.workspace

        x_2d = x.view(-1, x.shape[-1])

        size_m = x_2d.shape[0]
        size_k = x_2d.shape[1]
        size_n = s_ch.shape[1]

        x_int8, s_tok = ops.scaled_int8_quant(x_2d)

        output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group,
                                        workspace, size_m, size_n, size_k)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
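As a concrete illustration of what QQQConfig.from_config expects, a quant_config.json / quantize_config.json in the model directory only needs the "wbits" and "group_size" keys, with values drawn from the supported lists above. A minimal sketch (the dictionary contents are hypothetical):

from vllm.model_executor.layers.quantization.qqq import QQQConfig

# Hypothetical contents of a QQQ checkpoint's quant_config.json.
config = {"wbits": 4, "group_size": 128}
print(QQQConfig.from_config(config))
# -> QQQConfig(weight_bits=4, group_size=128)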
(new file, 125 lines)
@@ -0,0 +1,125 @@
from typing import List

import numpy
import torch

from .marlin_utils_test import marlin_permute_weights
from .quant_utils import get_pack_factor, qqq_quantize_weights


def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
    # Permute
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
                           dtype=numpy.uint32)
    if group_size == size_k:
        for i in range(pack_factor):
            q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i
    else:
        for i in range(pack_factor):
            q_packed |= q_w[:, i::pack_factor] << num_bits * i

    q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)

    return q_packed


def get_qqq_scale_perms():
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: List[int] = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single


# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
def get_qqq_weight_perm(num_bits: int, quant_type: str):
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                    4 * (i % 4),
                    4 * (i % 4) + 1,
                    4 * (i % 4) + 2,
                    4 * (i % 4) + 3,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = numpy.array(perm_list)

    assert quant_type in ["per-channel",
                          "per-group"], "not supported quantization type"
    if num_bits == 4:
        if quant_type == "per-channel":
            interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3])
        else:
            interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    else:
        raise Exception("num_bits must be 4, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm


def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size):
    scale_perm, scale_perm_single = get_qqq_scale_perms()
    if group_size < size_k and group_size != -1:
        s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm]
        s_channel = s_channel.reshape(
            (-1, len(scale_perm_single)))[:, scale_perm_single]
        s_group = s_group.reshape((-1, size_n)).contiguous()
    else:
        s_channel = s_channel.reshape(
            (-1, len(scale_perm_single)))[:, scale_perm_single]
    s_channel = s_channel.reshape((-1, size_n)).contiguous()

    return s_group, s_channel


def marlin_qqq_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k
    quant_type = "per-channel" if group_size == size_k else "per-group"

    # Quantize
    w_ref, q_w, s_group, s_channel = qqq_quantize_weights(
        w, num_bits, group_size)

    # Reformat to marlin_qqq
    weight_perm = get_qqq_weight_perm(num_bits, quant_type)
    marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits,
                                        weight_perm, group_size)
    marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales(
        s_group, s_channel, size_k, size_n, group_size)

    # Create result
    res_list = [
        w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel
    ]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
@@ -205,6 +205,88 @@ def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
     )
 
 
+# QQQ employs different quant schemes for per-group and
+# per-channel quantization.
+def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
+    orig_device = w.device
+    size_k, size_n = w.shape
+
+    assert w.is_floating_point(), "w must be float"
+    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
+    assert group_size in SUPPORTED_GROUP_SIZES + [
+        size_k
+    ], f"Unsupported groupsize = {group_size}"
+
+    if group_size == -1:
+        group_size = size_k
+    assert group_size <= size_k
+
+    if group_size < size_k:
+        # Reshape to [groupsize, -1]
+        w = w.reshape((-1, group_size, size_n))
+        w = w.permute(1, 0, 2)
+        w = w.reshape((group_size, -1))
+
+        max_q_val = 2**num_bits - 1
+        half_q_val = (max_q_val + 1) // 2
+
+        # Compute scale for each group
+        s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
+        s_group *= 2 / max_q_val  # 2 => symmetric
+
+        # Quantize
+        q_w = torch.round(w / s_group).int()
+        q_w += half_q_val
+        q_w = torch.clamp(q_w, 0, max_q_val)
+        # Compute ref (dequantized)
+        w_ref = (q_w - half_q_val).half() * s_group
+
+        # Restore original shapes
+        def reshape_w(w):
+            w = w.reshape((group_size, -1, size_n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((size_k, size_n)).contiguous()
+            return w
+
+        q_w = reshape_w(q_w)
+        w_ref = reshape_w(w_ref)
+
+        # Compute int8 quantization scale for each channel
+        s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
+        s_channel /= 127.0
+        t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
+        w_ref = t_int8.half() * s_channel
+        s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)
+
+        # Fuse scales
+        s_group = (s_group.reshape(-1, size_n).contiguous() /
+                   s_channel).to(dtype=torch.half)
+    else:
+        max_q_val = 2**(num_bits - 1) - 1
+
+        # Compute scale for each channel
+        s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
+        s_channel /= max_q_val
+
+        # Quantize
+        q_w = torch.round(w / s_channel).int()
+        q_w = torch.clamp(q_w, -max_q_val, max_q_val)
+        # Compute ref (dequantized)
+        w_ref = q_w.half() * s_channel
+
+        s_group = torch.tensor([], dtype=torch.half)
+        # div 2 ** (8 - self.bits)) to offset right shift in unpacking
+        s_channel /= (2**(8 - num_bits))
+        s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)
+
+    return (
+        w_ref.to(device=orig_device),
+        q_w.to(device=orig_device),
+        s_group.to(device=orig_device),
+        s_channel.to(device=orig_device),
+    )
+
+
 def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
     orig_device = q_w.device
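For the per-group path above, the fused s_group is expressed relative to the per-channel int8 scale, so the original 4-bit weights reconstruct as roughly s_channel * s_group * (q_w - half_q_val), with half_q_val == 8 for 4-bit weights. A minimal sanity-check sketch (illustrative code, not part of the diff; the shapes are arbitrary):

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    qqq_quantize_weights)

size_k, size_n, group_size, num_bits = 256, 128, 128, 4
w = torch.randn(size_k, size_n, dtype=torch.float)

w_ref, q_w, s_group, s_channel = qqq_quantize_weights(w, num_bits, group_size)

# Undo the scale fusion: broadcast each per-group scale over its group_size
# rows, then combine it with the per-channel scale.
s_group_full = s_group.float().repeat_interleave(group_size, dim=0)
recon = s_channel.float() * s_group_full * (q_w - 8).float()

# The residual is bounded by the int8 re-quantization step applied to w_ref.
print((recon - w_ref.float()).abs().max())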