[Kernel][Core] Add AWQ support to the Marlin kernel (#6612)
Parent: 25e778aa16
Commit: 396d92d5e0
CMakeLists.txt

@@ -172,6 +172,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
+    "csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
     "csrc/quantization/fp8/fp8_marlin.cu"
     "csrc/custom_all_reduce.cu"
     "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
csrc/ops.h (12 changes)

@@ -89,15 +89,19 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                   int64_t size_k);
 
 torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
-                               torch::Tensor& b_scales, torch::Tensor& g_idx,
-                               torch::Tensor& perm, torch::Tensor& workspace,
-                               int64_t num_bits, int64_t size_m, int64_t size_n,
-                               int64_t size_k, bool is_k_full);
+                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
+                               torch::Tensor& g_idx, torch::Tensor& perm,
+                               torch::Tensor& workspace, int64_t num_bits,
+                               int64_t size_m, int64_t size_n, int64_t size_k,
+                               bool is_k_full, bool has_zp);
 
 torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
                                  int64_t num_bits);
 
+torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
+                                int64_t size_n, int64_t num_bits);
+
 torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
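To make the relationship between the two new entry points concrete, here is a minimal, hypothetical host-side sketch of how an AWQ checkpoint could flow through this path: repack the AWQ-packed weights once with awq_marlin_repack, then call gptq_marlin_gemm with has_zp = true. The helper name awq_marlin_forward, the include path, and the workspace sizing constants (min_thread_n = 64, max_par = 16) are assumptions for illustration; the scale and zero-point tensors are assumed to have already been permuted into the layout the Marlin kernel expects (that conversion happens in vLLM's Python layer and is not shown here).

#include <torch/all.h>

#include "ops.h"  // declarations shown above (assumed include path)

// Hypothetical helper: one AWQ linear layer through the Marlin path.
torch::Tensor awq_marlin_forward(
    torch::Tensor& a,              // [M, K], fp16 or bf16, CUDA
    torch::Tensor& awq_qweight,    // [K, N / pack_factor], int32, AWQ packing
    torch::Tensor& marlin_scales,  // [K / group_size, N], already permuted
    torch::Tensor& marlin_zp,      // packed zero-points, already permuted
    int64_t num_bits) {
  int64_t size_m = a.size(0);
  int64_t size_k = awq_qweight.size(0);
  int64_t pack_factor = 32 / num_bits;
  int64_t size_n = awq_qweight.size(1) * pack_factor;

  // One-time conversion from AWQ packing to Marlin's tiled layout.
  torch::Tensor marlin_qweight =
      awq_marlin_repack(awq_qweight, size_k, size_n, num_bits);

  // AWQ has no act_order, so g_idx/perm are empty; the kernel detects this
  // via g_idx.size(0) == 0.
  auto int_opts = torch::TensorOptions().dtype(torch::kInt).device(a.device());
  torch::Tensor g_idx = torch::empty({0}, int_opts);
  torch::Tensor perm = torch::empty({0}, int_opts);

  // Must hold at least (size_n / min_thread_n) * max_par ints (see the
  // workspace check in gptq_marlin.cu); 64 and 16 are assumed values here.
  torch::Tensor workspace = torch::zeros({size_n / 64 * 16}, int_opts);

  return gptq_marlin_gemm(a, marlin_qweight, marlin_scales, marlin_zp, g_idx,
                          perm, workspace, num_bits, size_m, size_n, size_k,
                          /*is_k_full=*/true, /*has_zp=*/true);
}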
csrc/quantization/fp8/fp8_marlin.cu

@@ -19,10 +19,10 @@
  * Adapted from https://github.com/IST-DASLab/marlin
  */
 
-#include "../gptq_marlin/gptq_marlin.cuh"
-#include "../gptq_marlin/gptq_marlin_dtypes.cuh"
+#include "../gptq_marlin/marlin.cuh"
+#include "../gptq_marlin/marlin_dtypes.cuh"
 
-using namespace gptq_marlin;
+using namespace marlin;
 
 #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)        \
   static_assert(std::is_same<scalar_t, half>::value ||   \

@@ -1224,16 +1224,15 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
               ", size_k = ", size_k);
 
   // Verify B
-  TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
-              " is not divisible by tile_size = ", gptq_marlin::tile_size);
-  TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
+  TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
+              " is not divisible by tile_size = ", marlin::tile_size);
+  TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
               "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
-              ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
-  TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
+              ", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
+  TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
               "b_q_weight.size(1) = ", b_q_weight.size(1),
-              " is not divisible by tile_size = ", gptq_marlin::tile_size);
-  int actual_size_n =
-      (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
+              " is not divisible by tile_size = ", marlin::tile_size);
+  int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
   TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
               ", actual_size_n = ", actual_size_n);
 

@@ -1274,11 +1273,9 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   num_groups = b_scales.size(0);
 
   // Verify workspace size
-  TORCH_CHECK(
-      size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
-      ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
-  int min_workspace_size =
-      (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
+  TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
+              ", is not divisible by min_thread_n = ", marlin::min_thread_n);
+  int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
   TORCH_CHECK(workspace.numel() >= min_workspace_size,
               "workspace.numel = ", workspace.numel(),
               " is below min_workspace_size = ", min_workspace_size);

@@ -1290,14 +1287,14 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         b_scales.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), num_bits, num_groups, group_size, dev,
         at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
-        gptq_marlin::max_par);
+        marlin::max_par);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     fp8_marlin::marlin_mm_f16i4<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
         c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m,
         size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size,
         dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
-        gptq_marlin::max_par);
+        marlin::max_par);
   } else {
     TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16");
   }
csrc/quantization/gptq_marlin/awq_marlin_repack.cu (new file, 269 lines)

@@ -0,0 +1,269 @@
#include "marlin.cuh"

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

namespace marlin {

template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {}

}  // namespace marlin

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                int64_t size_k, int64_t size_n,
                                int64_t num_bits) {
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
  return torch::empty({1, 1});
}

#else

namespace marlin {

template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  constexpr int pack_factor = 32 / num_bits;

  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  int start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  extern __shared__ int4 sh[];

  constexpr int tile_n_ints = tile_n_size / pack_factor;

  constexpr int stage_n_threads = tile_n_ints / 4;
  constexpr int stage_k_threads = tile_k_size;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * tile_n_size;
    int first_n_packed = first_n / pack_factor;

    int4* sh_ptr = sh + stage_size * pipe;

    if (threadIdx.x < stage_size) {
      int k_id = threadIdx.x / stage_n_threads;
      int n_id = threadIdx.x % stage_n_threads;

      int first_k = k_tile_id * tile_k_size;

      cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
                reinterpret_cast<int4 const*>(
                    &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
                                     first_n_packed + (n_id * 4)])));
    }

    cp_async_fence();
  };

  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    int warp_id = threadIdx.x / 32;
    int th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = warp_id * 16 + tc_col;
    int cur_n_packed = cur_n / pack_factor;
    int cur_n_pos = cur_n % pack_factor;

    constexpr int sh_stride = tile_n_ints;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    // Undo interleaving
    int cur_n_pos_unpacked;
    if constexpr (num_bits == 4) {
      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    } else {
      constexpr int undo_pack[4] = {0, 2, 1, 3};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    }
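    // Assuming the standard AWQ packing order: the eight 4-bit values of an
    // int32 are stored at nibble positions in the order 0,2,4,6,1,3,5,7, and
    // undo_pack above is the inverse permutation (0,4,1,5,2,6,3,7), so logical
    // element e is read from nibble undo_pack[e]. For 8-bit values the order
    // 0,2,1,3 is its own inverse.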

    uint32_t vals[8];
  #pragma unroll
    for (int i = 0; i < 4; i++) {
      int cur_elem = tc_row + tc_offsets[i];

      int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
      int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
                                          sh_stride * cur_elem];

      vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
      vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
    }

    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (num_bits == 4) {
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
  #pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else {
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
  #pragma unroll
      for (int i = 0; i < 4; i++) {
        res1 |= vals[pack_idx[i]] << (i * 8);
        res2 |= vals[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };

  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
  #pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };
  #pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    start_pipes(k_tile_id, n_tile_id);

    while (n_tile_id < n_tiles) {
  #pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}

}  // namespace marlin

#define CALL_IF(NUM_BITS)                                                   \
  else if (num_bits == NUM_BITS) {                                          \
    cudaFuncSetAttribute(                                                   \
        marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);       \
    marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>      \
        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(       \
            b_q_weight_ptr, out_ptr, size_k, size_n);                       \
  }

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", marlin::tile_k_size);
  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK(b_q_weight.size(0) == size_k,
              "b_q_weight.size(0) = ", b_q_weight.size(0),
              " is not size_k = ", size_k);
  TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
              "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
              ", size_n = ", size_n, ", pack_factor = ", pack_factor);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  torch::Tensor out = torch::empty(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);

  // Get ptrs
  uint32_t const* b_q_weight_ptr =
      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  if (false) {
  }
  CALL_IF(4)
  CALL_IF(8)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
  }

  return out;
}

#endif
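A quick sanity check on the output allocation above (using tile_size = 16 from marlin.cuh): the repack only permutes the packed 32-bit words into Marlin's 16x64 tile order, so the output holds exactly as many words as the input of shape size_k x (size_n / pack_factor):

\[
  \frac{\text{size\_k}}{16} \times \frac{\text{size\_n}\cdot 16}{\text{pack\_factor}}
  = \text{size\_k} \cdot \frac{\text{size\_n}}{\text{pack\_factor}},
  \qquad \text{pack\_factor} = \frac{32}{\text{num\_bits}}.
\]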
csrc/quantization/gptq_marlin/gptq_marlin.cu

@@ -19,8 +19,8 @@
  * Adapted from https://github.com/IST-DASLab/marlin
  */
 
-#include "gptq_marlin.cuh"
-#include "gptq_marlin_dtypes.cuh"
+#include "marlin.cuh"
+#include "marlin_dtypes.cuh"
 
 #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)        \
   static_assert(std::is_same<scalar_t, half>::value ||   \

@@ -32,7 +32,7 @@ inline std::string str(T x) {
   return std::to_string(x);
 }
 
-namespace gptq_marlin {
+namespace marlin {
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
 

@@ -72,10 +72,11 @@ __global__ void Marlin(
 }  // namespace gptq_marlin
 
 torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
-                               torch::Tensor& b_scales, torch::Tensor& g_idx,
-                               torch::Tensor& perm, torch::Tensor& workspace,
-                               int64_t num_bits, int64_t size_m, int64_t size_n,
-                               int64_t size_k, bool is_k_full) {
+                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
+                               torch::Tensor& g_idx, torch::Tensor& perm,
+                               torch::Tensor& workspace, int64_t num_bits,
+                               int64_t size_m, int64_t size_n, int64_t size_k,
+                               bool is_k_full) {
   TORCH_CHECK_NOT_IMPLEMENTED(false,
                               "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
   return torch::empty({1, 1});

@@ -264,6 +265,114 @@ dequant_8bit<nv_bfloat16>(int q) {
   return frag_b;
 }
 
+// Zero-point dequantizers
+
+template <typename scalar_t>
+__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit_zp(int q) {
+  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
+}
+
+template <>
+__device__ inline typename ScalarType<half>::FragB dequant_4bit_zp<half>(
+    int q) {
+  const int LO = 0x000f000f;
+  const int HI = 0x00f000f0;
+  const int EX = 0x64006400;
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
+
+  const int SUB = 0x64006400;
+  const int MUL = 0x2c002c00;
+  const int ADD = 0xd400d400;
+  typename ScalarType<half>::FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&SUB));
+  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&MUL),
+                      *reinterpret_cast<const half2*>(&ADD));
+  return frag_b;
+}
+
+template <>
+__device__ inline typename ScalarType<nv_bfloat16>::FragB
+dequant_4bit_zp<nv_bfloat16>(int q) {
+  static constexpr uint32_t MASK = 0x000f000f;
+  static constexpr uint32_t EX = 0x43004300;
+
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+  q >>= 4;
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+
+  typename ScalarType<nv_bfloat16>::FragB frag_b;
+  static constexpr uint32_t MUL = 0x3F803F80;
+  static constexpr uint32_t ADD = 0xC300C300;
+
+  frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
+                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
+                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
+                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
+                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  return frag_b;
+}
+
+template <typename scalar_t>
+__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit_zp(int q) {
+  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
+}
+
+template <>
+__device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
+    int q) {
+  static constexpr uint32_t mask_for_elt_01 = 0x5250;
+  static constexpr uint32_t mask_for_elt_23 = 0x5351;
+  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
+
+  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
+  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
+
+  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
+
+  typename ScalarType<half>::FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  return frag_b;
+}
+
+template <>
+__device__ inline typename ScalarType<nv_bfloat16>::FragB
+dequant_8bit_zp<nv_bfloat16>(int q) {
+  typename ScalarType<nv_bfloat16>::FragB frag_b;
+
+  float fp32_intermediates[4];
+  uint32_t* fp32_intermediates_casted =
+      reinterpret_cast<uint32_t*>(fp32_intermediates);
+
+  static constexpr uint32_t fp32_base = 0x4B000000;
+  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
+  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
+  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
+  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
+
+  fp32_intermediates[0] -= 8388608.f;
+  fp32_intermediates[1] -= 8388608.f;
+  fp32_intermediates[2] -= 8388608.f;
+  fp32_intermediates[3] -= 8388608.f;
+
+  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
+  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
+                                   fp32_intermediates_casted[1], 0x7632);
+  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
+                                   fp32_intermediates_casted[3], 0x7632);
+
+  return frag_b;
+}
+
 // Multiply dequantized values by the corresponding quantization scale; used
 // only for grouped quantization.
 template <typename scalar_t>
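A quick worked check of the magic constants in the new dequant_4bit_zp<half> specialization above (assuming standard IEEE fp16 encodings: 0x6400 = 1024.0, 0x2c00 = 1/16, 0xd400 = -64.0). For the low nibble q_lo and the high nibble q_hi of each 16-bit lane:

\[
  \mathrm{fp16}(0x6400 \mid q_{lo}) - 1024 = q_{lo}, \qquad
  \mathrm{fp16}(0x6400 \mid (q_{hi} \ll 4)) \cdot \tfrac{1}{16} - 64
    = \frac{1024 + 16\,q_{hi}}{16} - 64 = q_{hi}.
\]

So the _zp variants reconstruct the raw unsigned values q in [0, 2^num_bits), leaving the zero-point to be subtracted separately by sub_zp (added in the next hunk); the pre-existing dequant_4bit / dequant_8bit paths instead bake a fixed symmetric offset into their constants.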
@@ -277,6 +386,17 @@ __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
   frag_b[1] = __hmul2(frag_b[1], s);
 }
 
+template <typename scalar_t>
+__device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b,
+                              typename ScalarType<scalar_t>::scalar_t2& frag_zp,
+                              int i) {
+  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
+  scalar_t2 zp =
+      ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);
+  frag_b[0] = __hsub2(frag_b[0], zp);
+  frag_b[1] = __hsub2(frag_b[1], zp);
+}
+
 // Same as above, but for act_order (each K is multiplied individually)
 template <typename scalar_t>
 __device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,
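The sub_zp helper added above is the asymmetric-quantization counterpart of scale: for a group scale s and zero-point z, the effective per-element dequantization performed when has_zp is set is

\[
  \hat{w} = s \cdot (q - z),
\]

with sub_zp supplying the (q - z) step on a half2 / nv_bfloat162 pair and scale() applying the multiplication by s; the dispatch macros later in this file only instantiate has_zp together with has_act_order == false.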
@@ -404,6 +524,7 @@ template <typename scalar_t,  // compute dtype, half or nv_float16
           const int stages,          // number of stages for the async global->shared
                                      // fetch pipeline
           const bool has_act_order,  // whether act_order is enabled
+          const bool has_zp,         // whether zero-points are enabled
           const int group_blocks = -1  // number of consecutive 16x16 blocks
                                        // with a separate quantization scale
           >

@@ -413,6 +534,8 @@ __global__ void Marlin(
     int4* __restrict__ C,            // fp16 output buffer of shape mxn
     const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                           // (k/groupsize)xn
+    const int4* __restrict__ zp_ptr,      // 4bit packed zero-points of shape
+                                          // (k/groupsize)x(n/pack_factor)
     const int* __restrict__ g_idx,        // int32 group indices of shape k
     int num_groups,       // number of scale groups per output channel
     int prob_m,           // batch dimension m

@@ -437,6 +560,7 @@ __global__ void Marlin(
   using FragB = typename ScalarType<scalar_t>::FragB;
   using FragC = typename ScalarType<scalar_t>::FragC;
   using FragS = typename ScalarType<scalar_t>::FragS;
+  using FragZP = typename ScalarType<scalar_t>::FragZP;
 
   constexpr int pack_factor = 32 / num_bits;
 

@@ -566,6 +690,13 @@ __global__ void Marlin(
   int tb_n_warps = thread_n_blocks / 4;
   int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
 
+  // Zero-points sizes/strides
+  int zp_gl_stride = (prob_n / pack_factor) / 4;
+  constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
+  constexpr int zp_tb_groups = s_tb_groups;
+  constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
+  int zp_gl_rd_delta = zp_gl_stride;
+
   // Global A read index of current thread.
   int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
                 (threadIdx.x % a_gl_rd_delta_o);

@@ -605,6 +736,19 @@ __global__ void Marlin(
   int s_sh_wr = threadIdx.x;
   bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
 
+  // Zero-points
+  int zp_gl_rd;
+  if constexpr (has_zp) {
+    if constexpr (group_blocks == -1) {
+      zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
+    } else {
+      zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
+                 zp_sh_stride * slice_col + threadIdx.x;
+    }
+  }
+  int zp_sh_wr = threadIdx.x;
+  bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
+
   // We use a different scale layout for grouped and column-wise quantization as
   // we scale a `half2` tile in column-major layout in the former and in
   // row-major in the latter case.

@@ -616,6 +760,18 @@ __global__ void Marlin(
     s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
               (threadIdx.x % 32) % 4;
 
+  // Zero-points have the same read layout as the scales
+  // (without column-wise case)
+  constexpr int num_col_threads = 8;
+  constexpr int num_row_threads = 4;
+  constexpr int num_ints_per_thread = 8 / pack_factor;
+  int zp_sh_rd;
+  if constexpr (has_zp) {
+    zp_sh_rd = num_ints_per_thread * num_col_threads *
+                   ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
+               num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
+  }
+
   // Precompute which thread should not read memory in which iterations; this is
   // needed if there are more threads than required for a certain tilesize or
   // when the batchsize is not a multiple of 16.

@@ -664,14 +820,17 @@ __global__ void Marlin(
   int4* sh_a = sh;
   int4* sh_b = sh_a + (stages * a_sh_stage);
   int4* sh_g_idx = sh_b + (stages * b_sh_stage);
-  int4* sh_s = sh_g_idx + (stages * g_idx_stage);
+  int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
+  int4* sh_s = sh_zp + (stages * zp_sh_stage);
 
   // Register storage for double buffer of shared memory reads.
   FragA frag_a[2][thread_m_blocks];
   I4 frag_b_quant[2][b_thread_vecs];
   FragC frag_c[thread_m_blocks][4][2];
-  FragS frag_s[2][4];        // No act-order
-  FragS act_frag_s[2][4][4]; // For act-order
+  FragS frag_s[2][4];                    // No act-order
+  FragS act_frag_s[2][4][4];             // For act-order
+  int frag_qzp[2][num_ints_per_thread];  // Zero-points
+  FragZP frag_zp;                        // Zero-points in fp16
 
   // Zero accumulators.
   auto zero_accums = [&]() {

@@ -777,6 +936,28 @@ __global__ void Marlin(
         }
       }
     }
+
+    if constexpr (has_zp && group_blocks != -1) {
+      int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
+
+      if constexpr (group_blocks >= thread_k_blocks) {
+        // Only fetch zero-points if this tile starts a new group
+        if (pipe % (group_blocks / thread_k_blocks) == 0) {
+          if (zp_sh_wr_pred) {
+            cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
+          }
+          zp_gl_rd += zp_gl_rd_delta;
+        }
+      } else {
+        for (int i = 0; i < zp_tb_groups; i++) {
+          if (zp_sh_wr_pred) {
+            cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
+                      &zp_ptr[zp_gl_rd]);
+          }
+          zp_gl_rd += zp_gl_rd_delta;
+        }
+      }
+    }
   }
   }
   // Insert a fence even when we are winding down the pipeline to ensure that

@@ -784,6 +965,12 @@ __global__ void Marlin(
     cp_async_fence();
   };
 
+  auto fetch_zp_to_shared = [&]() {
+    if (zp_sh_wr_pred) {
+      cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
+    }
+  };
+
   // Wait until the next thread tile has been loaded to shared memory.
   auto wait_for_stage = [&]() {
     // We only have `stages - 2` active fetches since we are double buffering

@@ -932,8 +1119,73 @@ __global__ void Marlin(
     }
   };
 
+  auto fetch_zp_to_registers = [&](int k, int full_pipe) {
+    if constexpr (!has_zp) {
+      return;
+    }
+
+    int pipe = full_pipe % stages;
+
+    if constexpr (group_blocks == -1) {
+      for (int i = 0; i < num_ints_per_thread; i++) {
+        frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
+      }
+
+    } else if constexpr (group_blocks >= thread_k_blocks) {
+      int4* sh_zp_stage =
+          sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
+                                 (pipe / (group_blocks / thread_k_blocks)));
+      for (int i = 0; i < num_ints_per_thread; i++) {
+        frag_qzp[k % 2][i] =
+            (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
+      }
+    } else {
+      int warp_id = threadIdx.x / 32;
+      int n_warps = thread_n_blocks / 4;
+
+      int warp_row = warp_id / n_warps;
+
+      int cur_k = warp_row * 16;
+      cur_k += k_iter_size * (k % b_sh_wr_iters);
+
+      int k_blocks = cur_k / 16;
+      int cur_group_id = k_blocks / group_blocks;
+
+      int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
+
+      sh_zp_stage += cur_group_id * zp_sh_stride;
+
+      for (int i = 0; i < num_ints_per_thread; i++) {
+        frag_qzp[k % 2][i] =
+            (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
+      }
+    }
+  };
+
   // Execute the actual tensor core matmul of a sub-tile.
   auto matmul = [&](int k) {
+    if constexpr (has_zp) {
+      FragB frag_zp_0;
+      FragB frag_zp_1;
+      if constexpr (num_bits == 4) {
+        int zp_quant = frag_qzp[k % 2][0];
+        int zp_quant_shift = zp_quant >> 8;
+        frag_zp_0 = dequant_4bit_zp<scalar_t>(zp_quant);
+        frag_zp_1 = dequant_4bit_zp<scalar_t>(zp_quant_shift);
+
+      } else {
+        int zp_quant_0 = frag_qzp[k % 2][0];
+        int zp_quant_1 = frag_qzp[k % 2][1];
+        frag_zp_0 = dequant_8bit_zp<scalar_t>(zp_quant_0);
+        frag_zp_1 = dequant_8bit_zp<scalar_t>(zp_quant_1);
+      }
+
+      frag_zp[0] = frag_zp_0[0];
+      frag_zp[1] = frag_zp_0[1];
+      frag_zp[2] = frag_zp_1[0];
+      frag_zp[3] = frag_zp_1[1];
+    }
+
 // We have the m dimension as the inner loop in order to encourage overlapping
 // dequantization and matmul operations.
 #pragma unroll
@@ -944,16 +1196,32 @@ __global__ void Marlin(
         int b_quant = frag_b_quant[k % 2][0][j];
         int b_quant_shift = b_quant >> 8;
 
-        frag_b0 = dequant_4bit<scalar_t>(b_quant);
-        frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
+        if constexpr (has_zp) {
+          frag_b0 = dequant_4bit_zp<scalar_t>(b_quant);
+          frag_b1 = dequant_4bit_zp<scalar_t>(b_quant_shift);
+
+        } else {
+          frag_b0 = dequant_4bit<scalar_t>(b_quant);
+          frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
+        }
 
       } else {
         int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
         int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
         int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
 
-        frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
-        frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
+        if constexpr (has_zp) {
+          frag_b0 = dequant_8bit_zp<scalar_t>(b_quant_0);
+          frag_b1 = dequant_8bit_zp<scalar_t>(b_quant_1);
+        } else {
+          frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
+          frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
+        }
+      }
+
+      // Apply zero-point to frag_b0
+      if constexpr (has_zp) {
+        sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
       }
 
       // Apply scale to frag_b0

@@ -967,6 +1235,11 @@ __global__ void Marlin(
       }
     }
 
+    // Apply zero-point to frag_b1
+    if constexpr (has_zp) {
+      sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
+    }
+
     // Apply scale to frag_b1
     if constexpr (has_act_order) {
       scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],

@@ -1189,6 +1462,12 @@ __global__ void Marlin(
       }
       fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
     }
+
+    if constexpr (has_zp && group_blocks == -1) {
+      if (i == 0) {
+        fetch_zp_to_shared();
+      }
+    }
     fetch_to_shared(i, i, i < slice_iters);
   }
 

@@ -1197,6 +1476,7 @@ __global__ void Marlin(
     init_same_group(0);
     fetch_to_registers(0, 0);
     fetch_scales_to_registers(0, 0);
+    fetch_zp_to_registers(0, 0);
     a_gl_rd += a_gl_rd_delta_o * (stages - 1);
     slice_k_start_shared_fetch += tb_k * (stages - 1);
   };

@@ -1217,6 +1497,7 @@ __global__ void Marlin(
     for (int k = 0; k < b_sh_wr_iters; k++) {
       fetch_to_registers(k + 1, pipe % stages);
       fetch_scales_to_registers(k + 1, pipe);
+      fetch_zp_to_registers(k + 1, pipe);
       if (k == b_sh_wr_iters - 2) {
         fetch_to_shared((pipe + stages - 1) % stages, pipe,
                         slice_iters >= stages);

@@ -1354,6 +1635,7 @@ __global__ void Marlin(
 
       } else {
         s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
+        zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
       }
 
       start_pipes();

@@ -1363,22 +1645,24 @@ __global__ void Marlin(
 }
 
 #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,                  \
-                  THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS)   \
+                  THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS,        \
+                  NUM_THREADS)                                                 \
   else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS &&       \
            thread_n_blocks == THREAD_N_BLOCKS &&                               \
            thread_k_blocks == THREAD_K_BLOCKS &&                               \
-           has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS &&   \
-           num_threads == NUM_THREADS) {                                       \
+           has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP &&               \
+           group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {       \
     cudaFuncSetAttribute(                                                      \
         Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,               \
                THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER,   \
-               GROUP_BLOCKS>,                                                  \
+               HAS_ZP, GROUP_BLOCKS>,                                          \
         cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);          \
     Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,                   \
            THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER,       \
-           GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>(     \
-        A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n,     \
-        prob_k, locks);                                                        \
+           HAS_ZP, GROUP_BLOCKS>                                               \
+        <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                     \
+            A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups,         \
+            prob_m, prob_n, prob_k, locks);                                    \
   }
 
 typedef struct {
@@ -1548,39 +1832,61 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
   return exec_config_t{0, {-1, -1, -1}};
 }
 
-#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS)              \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)      \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)      \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)      \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)      \
-                                                                        \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)    \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)     \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)     \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)     \
-                                                                        \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)    \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)     \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)     \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)     \
-                                                                        \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)    \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)     \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)     \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)     \
-                                                                        \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)    \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)     \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)     \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
+#define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS)               \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS)     \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS)     \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS)     \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS)     \
+                                                                              \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS)   \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS)    \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS)    \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)    \
+                                                                              \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS)   \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS)    \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS)    \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)    \
+                                                                              \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS)   \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS)    \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS)    \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)    \
+                                                                              \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS)   \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS)    \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS)    \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
+
+#define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS)                \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS)    \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)     \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)     \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)     \
+                                                                              \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS)    \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)     \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)     \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)     \
+                                                                              \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS)    \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)     \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)     \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)     \
+                                                                              \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS)    \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)     \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)     \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
 
 template <typename scalar_t>
-void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
+void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
                      void* g_idx, void* perm, void* a_tmp, int prob_m,
                      int prob_n, int prob_k, void* workspace, int num_bits,
-                     bool has_act_order, bool is_k_full, int num_groups,
-                     int group_size, int dev, cudaStream_t stream, int thread_k,
-                     int thread_n, int sms, int max_par) {
+                     bool has_act_order, bool is_k_full, bool has_zp,
+                     int num_groups, int group_size, int dev,
+                     cudaStream_t stream, int thread_k, int thread_n, int sms,
+                     int max_par) {
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);
   TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,

@@ -1665,6 +1971,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
   const int4* B_ptr = (const int4*)B;
   int4* C_ptr = (int4*)C;
   const int4* s_ptr = (const int4*)s;
+  const int4* zp_ptr = (const int4*)zp;
   const int* g_idx_ptr = (const int*)g_idx;
   const int* perm_ptr = (const int*)perm;
   int4* a_tmp_ptr = (int4*)a_tmp;

@@ -1701,28 +2008,33 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
     thread_m_blocks = exec_cfg.max_m_blocks;
   }
 
   // Define kernel configurations
   if (false) {
   }
-  CALL_IF(4, 32, 2, 256)
-  CALL_IF(4, 16, 4, 256)
-  CALL_IF(4, 8, 8, 256)
-  CALL_IF(4, 8, 4, 128)
-  CALL_IF(4, 4, 8, 128)
-  CALL_IF(8, 32, 2, 256)
-  CALL_IF(8, 16, 4, 256)
-  CALL_IF(8, 8, 8, 256)
-  CALL_IF(8, 8, 4, 128)
-  CALL_IF(8, 4, 8, 128)
+  GPTQ_CALL_IF(4, 16, 4, 256)
+  GPTQ_CALL_IF(4, 8, 8, 256)
+  GPTQ_CALL_IF(4, 8, 4, 128)
+  GPTQ_CALL_IF(4, 4, 8, 128)
+  GPTQ_CALL_IF(8, 16, 4, 256)
+  GPTQ_CALL_IF(8, 8, 8, 256)
+  GPTQ_CALL_IF(8, 8, 4, 128)
+  GPTQ_CALL_IF(8, 4, 8, 128)
+
+  AWQ_CALL_IF(4, 16, 4, 256)
+  AWQ_CALL_IF(4, 8, 8, 256)
+  AWQ_CALL_IF(4, 8, 4, 128)
+  AWQ_CALL_IF(4, 4, 8, 128)
+  AWQ_CALL_IF(8, 16, 4, 256)
+  AWQ_CALL_IF(8, 8, 8, 256)
+  AWQ_CALL_IF(8, 8, 4, 128)
+  AWQ_CALL_IF(8, 4, 8, 128)
   else {
-    TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
-                           str(prob_n) + ", " + str(prob_k) + "]" +
-                           ", has_act_order = " + str(has_act_order) +
-                           ", num_groups = " + str(num_groups) +
-                           ", group_size = " + str(group_size) +
-                           ", thread_m_blocks = " + str(thread_m_blocks) +
-                           ", thread_n_blocks = " + str(thread_n_blocks) +
-                           ", thread_k_blocks = " + str(thread_k_blocks));
+    TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
+                ", ", prob_k, "]", ", has_act_order = ", has_act_order,
+                ", num_groups = ", num_groups, ", group_size = ", group_size,
+                ", thread_m_blocks = ", thread_m_blocks,
+                ", thread_n_blocks = ", thread_n_blocks,
+                ", thread_k_blocks = ", thread_k_blocks,
+                ", num_bits = ", num_bits);
   }
 
   A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;

@@ -1733,10 +2045,11 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
 }  // namespace gptq_marlin
 
 torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
-                               torch::Tensor& b_scales, torch::Tensor& g_idx,
-                               torch::Tensor& perm, torch::Tensor& workspace,
-                               int64_t num_bits, int64_t size_m, int64_t size_n,
-                               int64_t size_k, bool is_k_full) {
+                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
+                               torch::Tensor& g_idx, torch::Tensor& perm,
+                               torch::Tensor& workspace, int64_t num_bits,
+                               int64_t size_m, int64_t size_n, int64_t size_k,
+                               bool is_k_full, bool has_zp) {
   // Verify num_bits
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);

@@ -1749,16 +2062,15 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
               ", size_k = ", size_k);
 
   // Verify B
-  TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
-              " is not divisible by tile_size = ", gptq_marlin::tile_size);
-  TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
+  TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
+              " is not divisible by tile_size = ", marlin::tile_size);
+  TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
               "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
-              ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
-  TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
+              ", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
+  TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
               "b_q_weight.size(1) = ", b_q_weight.size(1),
-              " is not divisible by tile_size = ", gptq_marlin::tile_size);
-  int actual_size_n =
-      (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
+              " is not divisible by tile_size = ", marlin::tile_size);
+  int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
   TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
               ", actual_size_n = ", actual_size_n);
 

@@ -1772,6 +2084,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
   TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
 
+  TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
+  TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
+
   TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
   TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
 

@@ -1805,8 +2120,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   int group_size = -1;
   bool has_act_order = g_idx.size(0) != 0;
 
-  int b_rank = b_scales.sizes().size();
-  TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2");
+  int rank = b_scales.sizes().size();
+  TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
   TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
               " is not size_n = ", size_n);
   num_groups = b_scales.size(0);

@@ -1832,34 +2147,44 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
     }
   }
 
+  // Verify b_zeros
+  if (has_zp) {
+    int rank = b_zeros.sizes().size();
+    TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
+    TORCH_CHECK(b_zeros.size(0) == num_groups,
+                "b_zeros dim 0 = ", b_zeros.size(0),
+                " is not num_groups = ", num_groups);
+    TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
+                "b_zeros dim 1 = ", b_scales.size(1),
+                " is not size_n / pack_factor = ", size_n / pack_factor);
+  }
+
   // Verify workspace size
-  TORCH_CHECK(
-      size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
-      ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
-  int min_workspace_size =
-      (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
+  TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
+              ", is not divisible by min_thread_n = ", marlin::min_thread_n);
+  int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
   TORCH_CHECK(workspace.numel() >= min_workspace_size,
               "workspace.numel = ", workspace.numel(),
               " is below min_workspace_size = ", min_workspace_size);
 
   int dev = a.get_device();
   if (a.scalar_type() == at::ScalarType::Half) {
-    gptq_marlin::marlin_mm_f16i4<half>(
+    marlin::marlin_mm_f16i4<half>(
         a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
-        b_scales.data_ptr<at::Half>(), g_idx.data_ptr(), perm.data_ptr(),
-        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
-        workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups,
-        group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
-        thread_n, sms, gptq_marlin::max_par);
+        b_scales.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
+        perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
+        workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
+        num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
+        thread_k, thread_n, sms, marlin::max_par);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
-    gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(
+    marlin::marlin_mm_f16i4<nv_bfloat16>(
        a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
        c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
-       g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
-       size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order,
-       is_k_full, num_groups, group_size, dev,
-       at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
-       gptq_marlin::max_par);
+       b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
+       a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
+       workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
+       num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
+       thread_k, thread_n, sms, marlin::max_par);
   } else {
    TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
  }
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu

@@ -1,23 +1,16 @@
-#include "gptq_marlin.cuh"
-
-namespace gptq_marlin {
-
-static constexpr int repack_stages = 8;
-
-static constexpr int repack_threads = 256;
-
-static constexpr int tile_k_size = tile_size;
-static constexpr int tile_n_size = tile_k_size * 4;
+#include "marlin.cuh"
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
 
+namespace marlin {
+
 template <int const num_threads, int const num_bits, bool const has_perm>
-__global__ void marlin_repack_kernel(
+__global__ void gptq_marlin_repack_kernel(
     uint32_t const* __restrict__ b_q_weight_ptr,
     uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
     int size_k, int size_n) {}
 
-}  // namespace gptq_marlin
+}  // namespace marlin
 
 torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,

@@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
 
 #else
 
+namespace marlin {
+
 template <int const num_threads, int const num_bits, bool const has_perm>
-__global__ void marlin_repack_kernel(
+__global__ void gptq_marlin_repack_kernel(
     uint32_t const* __restrict__ b_q_weight_ptr,
     uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
     int size_k, int size_n) {

@@ -259,28 +254,28 @@ __global__ void gptq_marlin_repack_kernel(
   }
 }
 
-}  // namespace gptq_marlin
+}  // namespace marlin
 
-#define CALL_IF(NUM_BITS, HAS_PERM)                                          \
-  else if (num_bits == NUM_BITS && has_perm == HAS_PERM) {                   \
-    cudaFuncSetAttribute(                                                    \
-        gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads,       \
-                                          NUM_BITS, HAS_PERM>,               \
-        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \
-    gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
-                                      HAS_PERM>                              \
-        <<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>(   \
-            b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);              \
+#define CALL_IF(NUM_BITS, HAS_PERM)                                          \
+  else if (num_bits == NUM_BITS && has_perm == HAS_PERM) {                   \
+    cudaFuncSetAttribute(                                                    \
+        marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS,  \
+                                          HAS_PERM>,                         \
+        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \
+    marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS,      \
+                                      HAS_PERM>                              \
+        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(        \
+            b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);              \
   }
 
 torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
                                  int64_t num_bits) {
   // Verify compatibility with marlin tile of 16x64
-  TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
-              " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
-  TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
-              " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
+  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
+              " is not divisible by tile_k_size = ", marlin::tile_k_size);
+  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
+              " is not divisible by tile_n_size = ", marlin::tile_n_size);
 
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);

@@ -308,10 +303,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
   auto options = torch::TensorOptions()
                      .dtype(b_q_weight.dtype())
                      .device(b_q_weight.device());
-  torch::Tensor out =
-      torch::empty({size_k / gptq_marlin::tile_size,
-                    size_n * gptq_marlin::tile_size / pack_factor},
-                   options);
+  torch::Tensor out = torch::empty(
+      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
+      options);
 
   // Detect if there is act_order
   bool has_perm = perm.size(0) != 0;

csrc/quantization/gptq_marlin/marlin.cuh

@@ -9,7 +9,9 @@
 #include <cuda_runtime.h>
 #include <iostream>
 
-namespace gptq_marlin {
+namespace marlin {
+
+// Marlin params
 
 // 8 warps are a good choice since every SM has 4 schedulers and having more
 // than 1 warp per schedule allows some more latency hiding. At the same time,

@@ -25,6 +27,15 @@ static constexpr int min_thread_k = 64;
 static constexpr int tile_size = 16;
 static constexpr int max_par = 16;
 
+// Repack params
+static constexpr int repack_stages = 8;
+
+static constexpr int repack_threads = 256;
+
+static constexpr int tile_k_size = tile_size;
+static constexpr int tile_n_size = tile_k_size * 4;
+
+// Helpers
 template <typename T, int n>
 struct Vec {
   T elems[n];

@@ -73,4 +84,4 @@ __device__ inline void cp_async_wait() {
 
 #endif
 
-}  // namespace gptq_marlin
+}  // namespace marlin

csrc/quantization/gptq_marlin/marlin_dtypes.cuh

@@ -1,11 +1,11 @@
 
 #ifndef _data_types_cuh
 #define _data_types_cuh
-#include "gptq_marlin.cuh"
+#include "marlin.cuh"
 #include <cuda_fp16.h>
 #include <cuda_bf16.h>
 
-namespace gptq_marlin {
+namespace marlin {
 
 template <typename scalar_t>
 class ScalarType {};

@@ -23,6 +23,7 @@ class ScalarType<half> {
   using FragB = Vec<half2, 2>;
   using FragC = Vec<float, 4>;
   using FragS = Vec<half2, 1>;
+  using FragZP = Vec<half2, 4>;
 
   static __device__ float inline num2float(const half x) {
     return __half2float(x);

@@ -51,6 +52,7 @@ class ScalarType<nv_bfloat16> {
   using FragB = Vec<nv_bfloat162, 2>;
   using FragC = Vec<float, 4>;
   using FragS = Vec<nv_bfloat162, 1>;
+  using FragZP = Vec<nv_bfloat162, 4>;
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   static __device__ float inline num2float(const nv_bfloat16 x) {

@@ -72,6 +74,6 @@ class ScalarType<nv_bfloat16> {
 #endif
 };
 
-}  // namespace gptq_marlin
+}  // namespace marlin
 
 #endif
csrc/quantization/marlin/marlin_cuda_kernel.cu

@@ -30,7 +30,7 @@ inline std::string str(T x) {
   return std::to_string(x);
 }
 
-namespace marlin {
+namespace marlin_dense {
 
 constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
 

@@ -1040,7 +1040,7 @@ void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m,
   }
 }
 
-}  // namespace marlin
+}  // namespace marlin_dense
 
 torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                           torch::Tensor& b_scales, torch::Tensor& workspace,

@@ -1054,24 +1054,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   TORCH_CHECK(size_k == a.size(1),
               "Shape mismatch: a.size(1) = " + str(a.size(1)) +
                   ", size_k = " + str(size_k));
-  TORCH_CHECK(size_k % marlin::tile_size == 0,
-              "size_k = " + str(size_k) +
-                  " is not divisible by tile_size = " + str(marlin::tile_size));
-  TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
+  TORCH_CHECK(size_k % marlin_dense::tile_size == 0,
+              "size_k = " + str(size_k) + " is not divisible by tile_size = " +
+                  str(marlin_dense::tile_size));
+  TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0),
               "Shape mismatch: b_q_weight.size(0) = " +
                   str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
-                  ", tile_size = " + str(marlin::tile_size));
+                  ", tile_size = " + str(marlin_dense::tile_size));
 
   // Verify N
   TORCH_CHECK(b_scales.size(1) == size_n,
               "b_scales.size(1) = " + str(b_scales.size(1)) +
                   ", size_n = " + str(size_n));
-  TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
-              "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
-                  " is not divisible by tile_size = " + str(marlin::tile_size));
+  TORCH_CHECK(
+      b_q_weight.size(1) % marlin_dense::tile_size == 0,
+      "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
+          " is not divisible by tile_size = " + str(marlin_dense::tile_size));
 
-  int actual_size_n =
-      (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit;
+  int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) *
+                      marlin_dense::pack_factor_4bit;
   TORCH_CHECK(
       size_n == actual_size_n,
       "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));

@@ -1116,21 +1117,22 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
               "Unexpected groupsize = " + str(groupsize));
 
   // Verify workspace size
-  TORCH_CHECK(
-      size_n % marlin::min_thread_n == 0,
-      "size_n = " + str(size_n) +
-          ", is not divisible by min_thread_n = " + str(marlin::min_thread_n));
-  int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
+  TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0,
+              "size_n = " + str(size_n) +
+                  ", is not divisible by min_thread_n = " +
+                  str(marlin_dense::min_thread_n));
+  int min_workspace_size =
+      (size_n / marlin_dense::min_thread_n) * marlin_dense::max_par;
   TORCH_CHECK(workspace.numel() >= min_workspace_size,
               "workspace.numel = " + str(workspace.numel()) +
                   " is below min_workspace_size = " + str(min_workspace_size));
 
   int dev = a.get_device();
-  marlin::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
-                      b_scales.data_ptr(), size_m, size_n, size_k,
-                      workspace.data_ptr(), groupsize, dev,
-                      at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n,
-                      sms, marlin::max_par);
+  marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
+                            b_scales.data_ptr(), size_m, size_n, size_k,
+                            workspace.data_ptr(), groupsize, dev,
+                            at::cuda::getCurrentCUDAStream(dev), thread_k,
+                            thread_n, sms, marlin_dense::max_par);
 
   return c;
 }
@ -141,6 +141,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gptq_marlin_repack", &gptq_marlin_repack);
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);

// awq_marlin repack from AWQ.
ops.def("awq_marlin_repack", &awq_marlin_repack);
ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);

// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
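Once these bindings are registered, the repack kernel is reachable from Python through vLLM's thin wrappers (see the vllm/_custom_ops.py hunk further below). A minimal sketch, assuming a CUDA build of this branch and an illustrative 4-bit layer shape:

# Sketch only: requires a CUDA device and a vLLM build containing this commit.
import torch
from vllm import _custom_ops as ops

num_bits = 4                      # assumed bit width
pack_factor = 32 // num_bits
size_k, size_n = 4096, 4096       # hypothetical layer shape

# AWQ checkpoints pack quantized values along the output (column) dimension.
qweight_awq = torch.zeros((size_k, size_n // pack_factor),
                          dtype=torch.int32, device="cuda")

marlin_qweight = ops.awq_marlin_repack(qweight_awq, size_k, size_n, num_bits)
# Expected layout mirrors the gptq repack: (size_k // 16, size_n * 16 // pack_factor)
print(marlin_qweight.shape)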
@ -12,16 +12,18 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS,
marlin_permute_scales)
MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS,
marlin_make_empty_g_idx, marlin_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
pack_fp8_to_int32)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace, get_weight_perm, marlin_quantize, marlin_weights)
MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
marlin_weights)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
marlin_24_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights)
awq_pack, gptq_pack, quantize_weights, quantize_weights_with_zp,
sort_weights)

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
@ -57,12 +59,12 @@ def rand_data(shape, dtype=torch.float16):
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
mnk_factors):
def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
mnk_factors):
m_factor, n_factor, k_factor = mnk_factors

size_m = m_factor
@ -120,12 +122,60 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
mnk_factors):
m_factor, n_factor, k_factor = mnk_factors

size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor

print(f"MNK = {size_m} {size_n} {size_k}")

# Normalize group_size
if group_size == -1:
group_size = size_k
assert group_size <= size_k

# Create input
b_weight = rand_data((size_k, size_n))

# Quantize
w_ref, q_w, s, zp = quantize_weights_with_zp(b_weight, num_bits,
group_size)

# Pack to AWQ format
q_w_awq = awq_pack(q_w, num_bits, size_k, size_n)

# Pack to Marlin format
weight_perm = get_weight_perm(num_bits)
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)

# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.awq_marlin_repack(
q_w_awq,
size_k,
size_n,
num_bits,
)
torch.cuda.synchronize()

assert torch.allclose(marlin_q_w_1, marlin_q_w_2)


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
def test_marlin_gemm(
def test_gptq_marlin_gemm(
k_chunk,
n_chunk,
num_bits,
@ -155,6 +205,8 @@ def test_marlin_gemm(
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, num_bits, group_size, act_order)

marlin_zp = marlin_make_empty_g_idx(marlin_s.device)

workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)

@ -162,6 +214,7 @@ def test_marlin_gemm(
a_input,
marlin_q_w,
marlin_s,
marlin_zp,
g_idx,
sort_indices,
workspace.scratch,
@ -170,6 +223,7 @@ def test_marlin_gemm(
b_weight.shape[1],
a_input.shape[1],
is_k_full,
has_zp=False,
)
output_ref = torch.matmul(a_input, w_ref)

@ -188,7 +242,8 @@ def test_marlin_gemm(
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size,
mnk_factors):
m_factor, n_factor, k_factor = mnk_factors

size_m = m_factor
@ -301,3 +356,65 @@ def test_fp8_marlin_gemm(
print("max_diff = {}".format(max_diff))

assert max_diff < 0.04


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_gemm(
k_chunk,
n_chunk,
num_bits,
group_size,
mnk_factors,
):
m_factor, n_factor, k_factor = mnk_factors

size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor

print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")

a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))

w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, num_bits, group_size)

g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
is_k_full = True
has_zp = True

workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)

output = ops.gptq_marlin_gemm(
a_input,
marlin_q_w,
marlin_s,
marlin_zp,
g_idx,
sort_indices,
workspace.scratch,
num_bits,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full,
has_zp,
)
output_ref = torch.matmul(a_input, w_ref)

torch.cuda.synchronize()

max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))

assert max_diff < 0.04
@ -44,9 +44,9 @@ MODEL_ARG_EXPTYPES = [
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "awq", "ERROR"),

# AUTOAWQ
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", None, "awq"),
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", None, "awq_marlin"),
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "awq", "awq"),
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "marlin", "ERROR"),
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "marlin", "awq_marlin"),
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "gptq", "ERROR"),
]
@ -276,14 +276,22 @@ def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
num_bits)


# gptq_marlin
def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
num_bits: int) -> torch.Tensor:
return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)


def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, g_idx: torch.Tensor,
perm: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int, size_k: int,
is_k_full: bool) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
workspace, num_bits, size_m, size_n,
size_k, is_k_full)
b_scales: torch.Tensor, b_zeros: torch.Tensor,
g_idx: torch.Tensor, perm: torch.Tensor,
workspace: torch.Tensor, num_bits: int, size_m: int,
size_n: int, size_k: int, is_k_full: bool,
has_zp: bool) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
g_idx, perm, workspace, num_bits,
size_m, size_n, size_k, is_k_full,
has_zp)


# fp8 marlin
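The extended wrapper above routes both formats through one kernel entry point: GPTQ passes has_zp=False with an empty b_zeros, while AWQ passes has_zp=True with packed zero-points. A hedged shape/ABI sketch follows (placeholder tensors only, not numerically meaningful), assuming a CUDA build of this branch:

# Sketch only: placeholder tensors; real inputs come from awq_marlin_repack,
# marlin_permute_scales and awq_to_marlin_zero_points (see marlin_utils.py below).
import torch
from vllm import _custom_ops as ops

size_m, size_n, size_k = 16, 4096, 4096      # hypothetical problem size
num_bits, group_size = 4, 128
pack_factor = 32 // num_bits
num_groups = size_k // group_size

a = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
marlin_qw = torch.zeros((size_k // 16, size_n * 16 // pack_factor),
                        dtype=torch.int32, device="cuda")
marlin_s = torch.ones((num_groups, size_n), dtype=torch.float16, device="cuda")
marlin_zp = torch.zeros((num_groups, size_n // pack_factor),
                        dtype=torch.int32, device="cuda")
empty = torch.empty(0, dtype=torch.int, device="cuda")   # g_idx / perm unused
workspace = torch.zeros(size_n // 64 * 16, dtype=torch.int, device="cuda")

out = ops.gptq_marlin_gemm(a, marlin_qw, marlin_s, marlin_zp, empty, empty,
                           workspace, num_bits, size_m, size_n, size_k,
                           is_k_full=True, has_zp=True)
print(out.shape)                              # (size_m, size_n)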
@ -251,7 +251,7 @@ class ModelConfig:
f"supported in ROCm.")
if (self.quantization
not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin",
"fbgemm_fp8", "compressed_tensors")):
"awq_marlin", "fbgemm_fp8", "compressed_tensors")):
logger.warning(
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
@ -2,6 +2,7 @@ from typing import Dict, Type

from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.bitsandbytes import (
@ -31,6 +32,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"marlin": MarlinConfig,
"gptq_marlin_24": GPTQMarlin24Config,
"gptq_marlin": GPTQMarlinConfig,
"awq_marlin": AWQMarlinConfig,
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
"compressed-tensors": CompressedTensorsConfig,
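With the method registered above, an AWQ checkpoint can be served through the new path either automatically (via override_quantization_method in awq_marlin.py below) or by requesting it explicitly. A minimal usage sketch, assuming a CUDA build of this branch:

# Sketch only: explicitly selects the awq_marlin path for an AWQ checkpoint.
from vllm import LLM

llm = LLM(model="TheBloke/OpenHermes-2.5-Mistral-7B-AWQ",
          quantization="awq_marlin")
outputs = llm.generate("Hello, Marlin!")
print(outputs[0].outputs[0].text)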
268
vllm/model_executor/layers/quantization/awq_marlin.py
Normal file
@ -0,0 +1,268 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points,
check_awq_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_awq_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead

logger = init_logger(__name__)


class AWQMarlinConfig(QuantizationConfig):
"""Config class for AWQ Marlin"""

def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
lm_head_quantized: bool) -> None:
self.weight_bits = weight_bits
self.pack_factor = 32 // self.weight_bits  # packed into int32
self.group_size = group_size
self.has_zp = has_zp
self.lm_head_quantized = lm_head_quantized

verify_awq_marlin_supported(num_bits=self.weight_bits,
group_size=self.group_size,
has_zp=self.has_zp)

def __repr__(self) -> str:
return (f"AWQMarlinConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"has_zp={self.has_zp}, "
f"lm_head_quantized={self.lm_head_quantized})")

@classmethod
def get_name(cls) -> str:
return "awq_marlin"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]

@classmethod
def get_min_capability(cls) -> int:
return 80

@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
has_zp = cls.get_from_keys(config, ["zero_point"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, has_zp, lm_head_quantized)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin")

if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()

if can_convert and user_quant == "awq":
logger.info("Detected that the model can run with awq_marlin"
", however you specified quantization=awq explicitly,"
" so forcing awq. Use quantization=awq_marlin for"
" faster inference")
return None

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["AWQMarlinLinearMethod"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return AWQMarlinLinearMethod(self)
return None

def get_scaled_act_names(self) -> List[str]:
return []

@classmethod
def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
has_zp = quant_config.get("zero_point", None)

if quant_method != "awq":
return False

# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or has_zp is None):
return False

return check_awq_marlin_supported(
num_bits=num_bits,
group_size=group_size,
has_zp=has_zp,
min_capability=cls.get_min_capability())


class AWQMarlinLinearMethod(LinearMethodBase):
"""Linear method for AWQ Marlin.

Args:
quant_config: The AWQ Marlin quantization config.
"""

def __init__(self, quant_config: AWQMarlinConfig) -> None:
self.quant_config = quant_config

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
del output_size
output_size_per_partition = sum(output_partition_sizes)

# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size

verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=input_size,
group_size=group_size)

qweight = Parameter(
torch.empty(
input_size_per_partition,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
})

num_groups = input_size_per_partition // group_size

qzeros = Parameter(
torch.empty(
num_groups,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qzeros, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
})

scales = Parameter(
torch.empty(
num_groups,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(scales, {
"input_dim": 0,
"output_dim": 1,
})

layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("qzeros", qzeros)
set_weight_attrs(qzeros, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)

layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.num_groups = num_groups

# TODO: Update this docs
# Checkpoints are serialized in AutoAWQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device

# Allocate marlin workspace
layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device)

# Repack weights from AWQ format to marlin format.
marlin_qweight = ops.awq_marlin_repack(
layer.qweight,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.weight_bits)
replace_tensor(layer, "qweight", marlin_qweight)

# Permute scales from AWQ format to marlin format.
marlin_scales = marlin_permute_scales(
layer.scales,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size)
replace_tensor(layer, "scales", marlin_scales)

# Permute zero-points from AWQ format to marlin format.
marlin_zp = awq_to_marlin_zero_points(
layer.qzeros,
size_k=layer.num_groups,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.weight_bits)
replace_tensor(layer, "qzeros", marlin_zp)

# Not-used
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return apply_awq_marlin_linear(
input=x,
weight=layer.qweight,
weight_scale=layer.scales,
weight_zp=layer.qzeros,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
num_bits=self.quant_config.weight_bits,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
bias=bias)
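The config class above keys off the AutoAWQ fields bits, group_size and zero_point. A small standalone sketch of how a typical AWQ quantization config maps onto AWQMarlinConfig (the dict values are illustrative, not taken from a real checkpoint):

# Sketch only: illustrative config dict resembling an AutoAWQ quantize_config.
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig

awq_cfg = {
    "quant_method": "awq",
    "bits": 4,
    "group_size": 128,
    "zero_point": True,
}

# The compatibility check queries the GPU's capability (>= 80 required).
print(AWQMarlinConfig.is_awq_marlin_compatible(awq_cfg))

cfg = AWQMarlinConfig.from_config(awq_cfg)
print(cfg)   # AWQMarlinConfig(weight_bits=4, group_size=128, has_zp=True, ...)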
@ -7,8 +7,8 @@ from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_marlin_supported,
apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_gptq_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.utils import set_weight_attrs

@ -38,9 +38,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self.group_size = group_size

# Verify supported on platform.
verify_marlin_supported(num_bits=self.num_bits,
group_size=self.group_size,
is_sym=True)
verify_gptq_marlin_supported(num_bits=self.num_bits,
group_size=self.group_size,
is_sym=True)

def get_min_capability(self) -> int:
# ampere and up
@ -135,6 +135,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

# No zero-point
layer.weight_zp = marlin_make_empty_g_idx(device)

# Repack weights from compressed-tensors format to marlin format.
marlin_qweight = ops.gptq_marlin_repack(
layer.weight_packed.t().contiguous(),
@ -155,10 +158,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:

return apply_marlin_linear(
return apply_gptq_marlin_linear(
input=x,
weight=layer.weight_packed,
weight_scale=layer.weight_scale,
weight_zp=layer.weight_zp,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
@ -10,10 +10,10 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_marlin_linear, check_marlin_supported, marlin_is_k_full,
apply_gptq_marlin_linear, check_gptq_marlin_supported, marlin_is_k_full,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
verify_marlin_supported, verify_marlin_supports_shape)
verify_gptq_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead

logger = init_logger(__name__)
@ -37,9 +37,9 @@ class GPTQMarlinConfig(QuantizationConfig):
self.lm_head_quantized = lm_head_quantized

# Verify supported on platform.
verify_marlin_supported(num_bits=self.weight_bits,
group_size=self.group_size,
is_sym=self.is_sym)
verify_gptq_marlin_supported(num_bits=self.weight_bits,
group_size=self.group_size,
is_sym=self.is_sym)

def __repr__(self) -> str:
return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
@ -77,7 +77,7 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_marlin_compatible(hf_quant_cfg)
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)

is_valid_user_quant = (user_quant is None or user_quant == "marlin")

@ -105,22 +105,27 @@ class GPTQMarlinConfig(QuantizationConfig):
return []

@classmethod
def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
sym = quant_config.get("sym", None)
desc_act = quant_config.get("desc_act", None)

if quant_method != "gptq":
return False

# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False

return check_marlin_supported(num_bits=num_bits,
group_size=group_size,
is_sym=sym,
min_capability=cls.get_min_capability())
return check_gptq_marlin_supported(
num_bits=num_bits,
group_size=group_size,
is_sym=sym,
min_capability=cls.get_min_capability())


class GPTQMarlinLinearMethod(LinearMethodBase):
@ -278,6 +283,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

# No zero-point
layer.zp = marlin_make_empty_g_idx(device)

# Repack weights from autogptq format to marlin format.
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
@ -302,10 +310,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return apply_marlin_linear(
return apply_gptq_marlin_linear(
input=x,
weight=layer.qweight,
weight_scale=layer.scales,
weight_zp=layer.zp,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
@ -1,54 +1,92 @@
from typing import List, Optional, Tuple

import numpy
import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform

from .quant_utils import pack_cols, unpack_cols

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
GTPQ_MARLIN_UNSUPPORTED_GROUP_SIZE_ACT_ORDER = [-1]
MARLIN_SUPPORTED_NUM_BITS = [4, 8]
MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]


def check_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
min_capability: int) -> bool:
def _check_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
min_capability: Optional[int],
has_zp: bool) -> Tuple[bool, Optional[str]]:
if min_capability is not None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
if device_capability < min_capability:
return (False, "Marlin does not support device_capability = {}"
", the min_capability required is {}".format(
device_capability, min_capability))

# If the capability of the device is too low, cannot convert.
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
if device_capability < min_capability:
return False
if num_bits not in MARLIN_SUPPORTED_NUM_BITS:
return (False, "Marlin does not support weight_bits = {}. "
"Only weight_bits = {} are supported.".format(
num_bits, MARLIN_SUPPORTED_NUM_BITS))

return (device_capability >= min_capability
and num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
and is_sym in GPTQ_MARLIN_SUPPORTED_SYM)
if group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
return (False, "Marlin does not support group_size = {}. Only "
"group_sizes = {} are supported.".format(
group_size, MARLIN_SUPPORTED_GROUP_SIZES))

if not has_zp and not is_sym:
return (False,
"Marlin without zero_points must have symmetric quantization")

return True, None


def verify_marlin_supported(num_bits: int, group_size: Optional[int],
is_sym: bool) -> None:
def check_gptq_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
min_capability: int) -> bool:
cond, _ = _check_marlin_supported(num_bits,
group_size,
is_sym,
min_capability,
has_zp=False)
return cond

if num_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
raise ValueError(
f"Marlin does not support weight_bits = {num_bits}. "
f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
"are supported.")
if (group_size is None
or group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES):
raise ValueError(
f"Marlin does not support group_size = {group_size}. "
f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
"are supported.")
if is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
raise ValueError(
f"Marlin does not support is_sym = is_sym. "
f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")

def check_awq_marlin_supported(num_bits: int, group_size: int, has_zp: bool,
min_capability: int) -> bool:
cond, _ = _check_marlin_supported(num_bits,
group_size,
False,
min_capability,
has_zp=has_zp)
return cond


def verify_gptq_marlin_supported(num_bits: int, group_size: int,
is_sym: bool) -> None:
cond, err_msg = _check_marlin_supported(num_bits,
group_size,
is_sym,
min_capability=None,
has_zp=False)
if not cond:
assert err_msg is not None
raise ValueError("GPTQ" + err_msg)


def verify_awq_marlin_supported(num_bits: int, group_size: int,
has_zp: bool) -> None:
cond, err_msg = _check_marlin_supported(num_bits,
group_size,
False,
min_capability=None,
has_zp=has_zp)
if not cond:
assert err_msg is not None
raise ValueError("AWQ" + err_msg)


def verify_marlin_supports_shape(output_size_per_partition: int,
@ -138,6 +176,51 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
return s


def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
num_bits: int) -> torch.Tensor:
# Permute zero-points in a similar way to scales, but do not use the
# "single" permutation, since zero-points are applied on every MMA
scale_perm, _ = get_scale_perms()
zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]

# Interleave column dim (for the dequantize code) and pack it to int32
if num_bits == 4:
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = numpy.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
zp = zp.reshape((-1, size_n)).contiguous()
zp = pack_cols(zp, num_bits, size_k, size_n)

return zp


def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
size_n: int, num_bits: int) -> torch.Tensor:
# AWQ zero-points are quantized and packed on the column dim.
# In addition, the values are permuted based on dequantizer.
# Here we undo both of these, and then apply marlin permutation
# and pack it back.
q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

# Undo interleaving (use argsort(..) to get inverse perm)
if num_bits == 4:
undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
elif num_bits == 8:
undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
q_zp = q_zp.reshape((-1, size_n)).contiguous()

marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
return marlin_zp


# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(layer: torch.nn.Module, name: str,
@ -149,23 +232,26 @@ def replace_tensor(layer: torch.nn.Module, name: str,
del new_t


def apply_marlin_linear(input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
output_size_per_partition: int,
input_size_per_partition: int,
is_k_full: bool,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_gptq_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_zp: torch.Tensor,
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
output_size_per_partition: int,
input_size_per_partition: int,
is_k_full: bool,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition, )

output = ops.gptq_marlin_gemm(reshaped_x,
weight,
weight_scale,
weight_zp,
g_idx,
g_idx_sort_indices,
workspace,
@ -173,7 +259,43 @@ def apply_marlin_linear(input: torch.Tensor,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=is_k_full)
is_k_full=is_k_full,
has_zp=False)

if bias is not None:
output.add_(bias)  # In-place add

return output.reshape(out_shape)


def apply_awq_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_zp: torch.Tensor,
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
output_size_per_partition: int,
input_size_per_partition: int,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition, )

output = ops.gptq_marlin_gemm(reshaped_x,
weight,
weight_scale,
weight_zp,
g_idx,
g_idx_sort_indices,
workspace,
num_bits,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=True,
has_zp=True)

if bias is not None:
output.add_(bias)  # In-place add
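awq_to_marlin_zero_points above relies on argsort of the AWQ column interleave being its inverse permutation, so the unpack / undo-interleave / repack round trip is lossless. That identity can be checked in isolation with the following standalone snippet (not part of the diff):

# Standalone check of the interleave / undo-interleave identity used above.
import numpy as np

interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])      # 4-bit case
undo = np.argsort(interleave)                        # inverse permutation

x = np.arange(32).reshape(-1, len(interleave))       # toy data
restored = x[:, interleave][:, undo]

assert (restored == x).all()
print(undo)                                          # [0 4 1 5 2 6 3 7]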
@ -2,11 +2,13 @@

from typing import List

import numpy
import numpy as np
import torch

from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales
from .quant_utils import get_pack_factor, quantize_weights, sort_weights
from .marlin_utils import (GPTQ_MARLIN_TILE, marlin_permute_scales,
marlin_zero_points)
from .quant_utils import (get_pack_factor, quantize_weights,
quantize_weights_with_zp, sort_weights)


class MarlinWorkspace:
@ -46,14 +48,14 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
pack_factor = get_pack_factor(num_bits)
orig_device = q_w.device

q_w = q_w.cpu().numpy().astype(numpy.uint32)
q_w = q_w.cpu().numpy().astype(np.uint32)

q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
dtype=numpy.uint32)
q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
dtype=np.uint32)
for i in range(pack_factor):
q_packed |= q_w[:, i::pack_factor] << num_bits * i

q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)
q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)

return q_packed

@ -74,12 +76,12 @@ def get_weight_perm(num_bits: int):
for j in range(4):
perm_list.extend([p + 256 * j for p in perm1])

perm = numpy.array(perm_list)
perm = np.array(perm_list)

if num_bits == 4:
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = numpy.array([0, 2, 1, 3])
interleave = np.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

@ -118,3 +120,32 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
res_list[i] = res_list[i].to(w.device)

return res_list


def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int):
size_k, size_n = w.shape

# Normalize group_size
if group_size == -1:
group_size = size_k
assert group_size <= size_k

# Detect num groups
assert size_k % group_size == 0
num_groups = size_k // group_size

# Quantize with zp
w_ref, q_w, s, zp = quantize_weights_with_zp(w, num_bits, group_size)

# Reformat to marlin
weight_perm = get_weight_perm(num_bits)
marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
marlin_zp = marlin_zero_points(zp, num_groups, size_n, num_bits)

# Create result
res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
for i in range(len(res_list)):
res_list[i] = res_list[i].to(w.device)

return res_list
@ -106,6 +106,67 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
)


def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
orig_device = w.device
size_k, size_n = w.shape

assert w.is_floating_point(), "w must be float"
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
assert group_size in SUPPORTED_GROUP_SIZES + [
size_k
], f"Unsupported groupsize = {group_size}"

if group_size == -1:
group_size = size_k
assert group_size <= size_k

max_q_val = 2**num_bits - 1
min_q_val = 0

# Reshape to [groupsize, -1]
if group_size < size_k:
w = w.reshape((-1, group_size, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((group_size, -1))

# Compute scale for each group
max = torch.max(w, 0, keepdim=True)[0]
min = torch.min(w, 0, keepdim=True)[0]
s = (max - min).clamp(min=1e-5) / max_q_val

# Compute zero-point for each group
zp = (-torch.round(min / s)).clamp(min_q_val, max_q_val).int()

# Quantize
q_w = torch.round(w / s).int() + zp
q_w = torch.clamp(q_w, min_q_val, max_q_val)

# Compute ref (dequantized)
w_ref = (q_w - zp).half() * s

# Restore original shapes
if group_size < size_k:

def reshape_w(w):
w = w.reshape((group_size, -1, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((size_k, size_n)).contiguous()
return w

q_w = reshape_w(q_w)
w_ref = reshape_w(w_ref)

s = s.reshape((-1, size_n)).contiguous()
zp = zp.reshape((-1, size_n)).contiguous()

return (
w_ref.to(device=orig_device),
q_w.to(device=orig_device),
s.to(device=orig_device),
zp.to(device=orig_device),
)

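quantize_weights_with_zp above chooses, per group, s = (max - min) / (2^b - 1) and zp = round(-min / s), so q = round(w / s) + zp lands in [0, 2^b - 1] and (q - zp) * s approximately recovers the weight. A tiny standalone check of that arithmetic (not part of the diff):

# Standalone illustration of the zero-point quantization math (num_bits = 4).
import torch

w = torch.tensor([[-0.30, 0.05, 0.45]])
num_bits = 4
max_q = 2**num_bits - 1                                  # 15

s = (w.max() - w.min()).clamp(min=1e-5) / max_q          # ~0.05
zp = (-torch.round(w.min() / s)).clamp(0, max_q).int()   # 6

q = torch.clamp(torch.round(w / s).int() + zp, 0, max_q)
w_hat = (q - zp).float() * s

print(q)       # tensor([[ 0,  7, 15]], dtype=torch.int32)
print(w_hat)   # approximately [[-0.30, 0.05, 0.45]]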
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
orig_device = q_w.device

@ -122,7 +183,7 @@ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
)


def gptq_pack(
def pack_rows(
q_w: torch.Tensor,
num_bits: int,
size_k: int,
@ -144,3 +205,90 @@ def gptq_pack(

q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
return q_res


def pack_cols(
q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
assert q_w.shape == (size_k, size_n)

pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0

orig_device = q_w.device

q_w = q_w.cpu().numpy().astype(numpy.uint32)

q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

for i in range(pack_factor):
q_res |= q_w[:, i::pack_factor] << num_bits * i

q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()

return q_res


def unpack_cols(
packed_q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
assert packed_q_w.shape == (
size_k, size_n // pack_factor
), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
packed_q_w.shape, size_k, size_n, pack_factor)

orig_device = packed_q_w.device

packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

mask = (1 << num_bits) - 1
for i in range(pack_factor):
vals = packed_q_w_cpu & mask
packed_q_w_cpu >>= num_bits
q_res[:, i::pack_factor] = vals

q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()

return q_res


def gptq_pack(
q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
return pack_rows(q_w, num_bits, size_k, size_n)


def awq_pack(
q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
assert q_w.shape == (size_k, size_n)

# Interleave column dim (for the dequantize code) and pack it to int32
if num_bits == 4:
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = numpy.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
q_w = q_w.reshape((-1, size_n)).contiguous()

return pack_cols(q_w, num_bits, size_k, size_n)
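pack_cols and unpack_cols are exact inverses for values that fit in num_bits, which is what the AWQ zero-point conversion in marlin_utils.py depends on. A quick standalone CPU round trip of the same bit layout (4-bit values, eight per int32; not part of the diff):

# Standalone round trip of the column-packing layout used by pack_cols/unpack_cols.
import numpy as np
import torch

num_bits, size_k, size_n = 4, 2, 16
pack_factor = 32 // num_bits

q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)
w = q_w.numpy().astype(np.uint32)

packed = np.zeros((size_k, size_n // pack_factor), dtype=np.uint32)
for i in range(pack_factor):
    packed |= w[:, i::pack_factor] << num_bits * i

unpacked = np.zeros((size_k, size_n), dtype=np.uint32)
mask = (1 << num_bits) - 1
tmp = packed.copy()
for i in range(pack_factor):
    unpacked[:, i::pack_factor] = tmp & mask
    tmp >>= num_bits

assert (unpacked == w).all()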