Add GPTQ support (#916)

This commit is contained in:
CHU Tianxiang 2023-12-15 19:04:22 +08:00 committed by GitHub
parent c06170cc8e
commit 0fbfc4b81b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 1782 additions and 82 deletions

View File

@ -84,7 +84,7 @@ if __name__ == '__main__':
parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization', parser.add_argument('--quantization',
'-q', '-q',
choices=['awq', 'squeezellm', None], choices=['awq', 'gptq', 'squeezellm', None],
default=None) default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32) parser.add_argument('--input-len', type=int, default=32)

View File

@ -244,7 +244,7 @@ if __name__ == "__main__":
parser.add_argument("--tokenizer", type=str, default=None) parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization', parser.add_argument('--quantization',
'-q', '-q',
choices=['awq', 'squeezellm', None], choices=['awq', 'gptq', 'squeezellm', None],
default=None) default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n", parser.add_argument("--n",

View File

@ -77,3 +77,15 @@ void squeezellm_gemm(
torch::Tensor mat, torch::Tensor mat,
torch::Tensor mul, torch::Tensor mul,
torch::Tensor lookup_table); torch::Tensor lookup_table);
torch::Tensor gptq_gemm(
torch::Tensor a,
torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx,
bool use_exllama);
void gptq_shuffle(
torch::Tensor q_weight,
torch::Tensor q_perm);

View File

@ -52,8 +52,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Quantization ops // Quantization ops
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
#endif #endif
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
// Cache ops // Cache ops

View File

@ -0,0 +1,64 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _compat_cuh
#define _compat_cuh
namespace vllm {
namespace gptq {
// atomicAdd for half types, to support CC < 7.x
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
}
while (assumed != old);
}
// atomicAdd for half2 types
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
unsigned int* address_as_ui = (unsigned int*)address;
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
half2 old_val = *((half2*)&old);
half2 new_val = __hadd2(old_val, val);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
}
while (assumed != old);
}
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif
#endif
#endif
} // namespace gptq
} // namespace vllm
#endif

View File

@ -0,0 +1,151 @@
/*
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
*/
#ifndef _matrix_view_cuh
#define _matrix_view_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "qdq_util.cuh"
namespace vllm {
namespace gptq {
class MatrixView_half
{
public:
const half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
{
half2* ptr = (half2*) item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __low2half(i01);
items[1] = __high2half(i01);
items[2] = __low2half(i23);
items[3] = __high2half(i23);
}
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
{
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __half2float(__low2half(i01));
items[1] = __half2float(__high2half(i01));
items[2] = __half2float(__low2half(i23));
items[3] = __half2float(__high2half(i23));
}
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
{
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __half2half2(__low2half(i01));
items[1] = __half2half2(__high2half(i01));
items[2] = __half2half2(__low2half(i23));
items[3] = __half2half2(__high2half(i23));
}
};
class MatrixView_half_rw
{
public:
half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
{
half2 v01 = __halves2half2(v0, v1);
half2 v23 = __halves2half2(v2, v3);
half2* ptr = (half2*) item_ptr(row, column);
ptr[0] = v01;
ptr[1] = v23;
}
};
class MatrixView_q4_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
{
int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
{
int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f;
items[2] = (d >> 8) & 0x0f;
items[3] = (d >> 12) & 0x0f;
}
};
class MatrixView_q4_column
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (row & 0x07) * 4;
return (data[row / 8 * width + column] >> shift) & 0x0f;
}
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};
} // namespace gptq
} // namespace vllm
#endif

View File

@ -0,0 +1,859 @@
/*
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopqwop200/GPTQ-for-LLaMa
*/
#include <cstdint>
#include <cstdio>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "compat.cuh"
#include "matrix_view.cuh"
#include "qdq_4.cuh"
namespace vllm {
namespace gptq {
#define BLOCK_KN_SIZE 128
#define BLOCK_M_SIZE_MAX 8
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
#define MAX_Q_GEMM_ROWS 50
#define MAX_ALT_GEMM_ROWS 8
#define THREADS_X 32
#define THREADS_Y 32
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
#if defined(USE_ROCM)
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
hipblasOperation_t transA,
hipblasOperation_t transB,
int m,
int n,
int k,
const half* alpha,
const half* AP,
int lda,
const half* BP,
int ldb,
const half* beta,
half* CP,
int ldc) {
return hipblasHgemm(handle, transA, transB, m, n, k,
reinterpret_cast<const hipblasHalf *>(alpha),
reinterpret_cast<const hipblasHalf *>(AP), lda,
reinterpret_cast<const hipblasHalf *>(BP), ldb,
reinterpret_cast<const hipblasHalf *>(beta),
reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define hipblasHgemm __compat_hipblasHgemm
// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_hgemm __compat_hipblasHgemm
#endif
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hadd2(result, g_result);
}
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __half2float(__low2half(result)) + __half2float(__high2half(result));
}
typedef void (*fp_gemm_half_q_half_gptq_kernel)
(
const half*,
const uint32_t*,
const uint32_t*,
const half*,
half*,
const int,
const int,
const int,
const int,
const int*
);
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
half* __restrict__ c,
const int size_m,
const int size_n,
const int size_k,
const int groups,
const int* __restrict__ b_q_perm
)
{
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int t = threadIdx.x;
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
for (int m = 0; m < m_count; ++m)
{
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0;
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
else a0 = a_ptr[offset_k + t];
block_a_ptr[t] = a0;
}
}
// Zero output
if (n >= size_n) return;
if (blockIdx.z == 0)
{
for (int m = 0; m < m_count; m++)
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// a, b offset
int qk = offset_k / (32 / 4);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
// Initial group
int zeros[4];
float scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
// Column result
float block_c[m_count][4] = {};
// Dequantize and multiply
int k = offset_k;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
}
#pragma unroll
for (int j = 0; j < 4; j++)
{
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
half2 dq[4][4];
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
#pragma unroll
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
}
b_ptr += size_n;
a_ptr += 8;
}
k += 32;
}
for (int m = 0; m < m_count; m++)
{
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
{
#if BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
#endif
return NULL;
}
void gemm_half_q_half_cuda_part
(
const half* a,
const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales,
const int* b_q_perm,
half* c,
int size_m,
int size_n,
int size_k,
int m_count,
int groups
)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
kernel<<<gridDim, blockDim>>>
(
a,
b_q_weight,
b_gptq_qzeros,
b_gptq_scales,
c,
size_m,
size_n,
size_k,
groups,
b_q_perm
);
}
__global__ void reconstruct_exllama_kernel
(
const uint32_t* __restrict__ b_q_weight,
const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
const int size_k,
const int size_n,
const int groups,
half* __restrict__ b
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
__shared__ int perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
if (b_q_perm)
{
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
}
// Column
int n = offset_n + t * 4;
if (n >= size_n) return;
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// b offset
int qk = offset_k / (32 / 4);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
// Initial zeros/scale
int zeros[4];
half2 scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
__syncthreads();
int k = offset_k;
int lk = 0;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
}
for (int p = 0; p < 4; p++)
{
half2 dq[4][4];
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
b_ptr += size_n;
//half* dqh = (half*)dq;
if (b_q_perm)
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
else
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
}
k += 32;
}
}
void reconstruct_exllama
(
const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales,
const int* b_q_perm,
half* out,
int height,
int width,
int groups
)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
reconstruct_exllama_kernel<<<gridDim, blockDim>>>
(
b_q_weight,
b_q_perm,
b_gptq_qzeros,
b_gptq_scales,
height,
width,
groups,
out
);
}
__global__ void gemm_half_q_half_alt_kernel(
const half2* __restrict__ vec,
const uint32_t* __restrict__ mat,
half* __restrict__ mul,
const half* __restrict__ scales,
const uint32_t* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int height,
int width
)
{
int zero_width = width / 8;
int vec_height = height * 4;
const int blockwidth2 = BLOCK_KN_SIZE / 2;
int b = blockIdx.y * BLOCK_M_SIZE_MAX;
int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
int h = BLOCK_KN_SIZE * blockIdx.z / 8;
int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
__shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
if (threadIdx.x < h_end) {
for (int m = 0; m < b_end; ++m) {
blockvec[m][threadIdx.x] =
vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 +
threadIdx.x];
}
}
__shared__ half2 deq2[256][8];
int val = threadIdx.x / 8;
int off = threadIdx.x % 8;
for (; val < 256; val += BLOCK_KN_SIZE / 8) {
deq2[val][off] = __halves2half2(
__int2half_rn(val & 0xF), __int2half_rn(val >> 4)
);
}
if (blockIdx.z == 0)
{
for (int m = 0; m < b_end; m++)
mul[(b + m) * width + w] = __int2half_rn(0);
}
__syncthreads();
int i = width * h + w;
int g_h = h * 8;
int k = 0;
int z_w = w / 8;
int z_mod = (w % 8) * 4;
half2 res2;
half res[BLOCK_M_SIZE_MAX] = {};
unsigned int tmp;
while (k < h_end) {
tmp = mat[i];
half2 scales_tmp[4];
half2 zeros_tmp[4];
for (int tmp_k = 0; tmp_k < 4; tmp_k++) {
int g = g_idx[g_h + (k + tmp_k) * 2];
int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1];
half scale_f = scales[g * width + w];
half scale_f2 = scales[g2 * width + w];
half2 scale = __halves2half2(scale_f, scale_f2);
half2 zero = __halves2half2(
__hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)),
__hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))
);
scales_tmp[tmp_k] = scale;
zeros_tmp[tmp_k] = zero;
}
for (int m = 0; m < b_end; m++) {
res2 = {};
res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2);
res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2);
res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
}
i += width;
k += 4;
}
for (int m = 0; m < b_end; m++) {
atomicAdd(&mul[(b + m) * width + w], res[m]);
}
}
void gemm_half_q_half_alt
(
const half* a,
const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales,
const int* b_g_idx,
half* c,
int size_m,
int size_n,
int size_k
)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE);
gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
gemm_half_q_half_alt_kernel<<<gridDim, blockDim>>>
(
(const half2*) a,
b_q_weight,
c,
b_gptq_scales,
b_gptq_qzeros,
b_g_idx,
size_m,
size_k / 8,
size_n
);
}
__global__ void reconstruct_gptq_kernel
(
const uint32_t* __restrict__ w,
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int* __restrict__ g_idx,
const int height,
const int width,
const int group,
half* __restrict__ out
)
{
// Start of block
int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
int row = blockIdx.y * 8;
if (column >= width) return;
// Views
MatrixView_q4_column w_(w, height, width);
MatrixView_half_rw out_(out, height, width);
MatrixView_half w_scales_(w_scales, group, width);
MatrixView_q4_row w_zeros_(w_zeros, group, width);
uint32_t w_read = w_.item_uint32_t(row, column);
half* out_ptr = out_.item_ptr(row, column);
#pragma unroll
for (int s = 0; s < 32; s += 4)
{
int group = g_idx[row + s / 4];
half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1;
half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
*out_ptr = w_item; out_ptr += out_.width;
}
}
void reconstruct_gptq
(
const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales,
const int* b_g_idx,
half* out,
int height,
int width,
int groups
)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
gridDim.y = DIVIDE(height, 8);
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
reconstruct_gptq_kernel<<<gridDim, blockDim>>>
(
b_q_weight,
b_gptq_scales,
b_gptq_qzeros,
b_g_idx,
height,
width,
groups,
out
);
}
void gemm_half_q_half_cuda
(
cublasHandle_t cublas_handle,
const half* a,
const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales,
const int* b_g_idx,
half* c,
half* temp_dq,
int size_m,
int size_n,
int size_k,
int groups,
bool use_exllama
)
{
if ((use_exllama && size_m > MAX_Q_GEMM_ROWS) || (!use_exllama && size_m > MAX_ALT_GEMM_ROWS)) {
// Reconstruct FP16 matrix, then cuBLAS
if (use_exllama) {
reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq,
size_k, size_n, groups);
}
else
{
reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
temp_dq, size_k, size_n, groups);
}
const half alpha = __float2half(1.0f);
const half beta = __float2half(0.0f);
cublasHgemm(cublas_handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
size_n, size_m, size_k,
&alpha, temp_dq, size_n,
a, size_k,
&beta, c, size_n);
}
else if (use_exllama)
{
// Quantized matmul
int max_chunks = size_m / BLOCK_M_SIZE_MAX;
int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
int last_chunk_size = size_m - last_chunk;
if (max_chunks)
{
gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX,
groups);
}
if (last_chunk_size)
{
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros,
b_gptq_scales, b_g_idx, c + last_chunk * size_n,
last_chunk_size, size_n, size_k, last_chunk_size,
groups);
}
}
else
{
gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
c, size_m, size_n, size_k);
}
}
__global__ void shuffle_kernel
(
uint32_t* __restrict__ b_q_weight,
const int size_k,
const int size_n
)
{
int n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
}
__global__ void make_sequential_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const int* __restrict__ q_perm,
const int w_height,
const int w_width
)
{
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
int q_perm_idx = w_new2_row << 3;
uint64_t dst = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
int source_row = q_perm[q_perm_idx++];
int w2_row = source_row >> 3;
int w2_subrow = source_row & 0x07;
int w2_row_shift = w2_subrow << 2;
int wnew2_row_shift = i << 2;
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x0000000f0000000f;
src <<= wnew2_row_shift;
dst |= src;
}
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
void shuffle_exllama_weight
(
uint32_t* q_weight,
int* q_perm,
int height,
int width
)
{
if (q_perm)
{
uint32_t* new_qweight = NULL;
cudaMalloc(&new_qweight, height / 8 * width * sizeof(uint32_t));
dim3 blockDim, gridDim;
blockDim.x = THREADS_X;
blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = height / 8;
make_sequential_kernel<<<gridDim, blockDim>>>
(
q_weight,
new_qweight,
q_perm,
height / 8,
width
);
// Replace qweights
cudaMemcpyAsync(q_weight, new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
// Cleanup
cudaDeviceSynchronize();
cudaFree(new_qweight);
}
dim3 blockDim, gridDim;
blockDim.x = THREADS_X;
blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = 1;
shuffle_kernel<<<gridDim, blockDim>>>(q_weight, height, width);
}
} // namespace gptq
} // namespace vllm
torch::Tensor gptq_gemm
(
torch::Tensor a,
torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx,
bool use_exllama
)
{
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 8, b_q_weight.size(1)}, options);
vllm::gptq::gemm_half_q_half_cuda
(
at::cuda::getCurrentCUDABlasHandle(),
(const half*) a.data_ptr(),
(const uint32_t*) b_q_weight.data_ptr(),
(const uint32_t*)b_gptq_qzeros.data_ptr(),
(const half*) b_gptq_scales.data_ptr(),
b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(),
(half*) c.data_ptr(),
(half*) temp_dq.data_ptr(),
c.size(0), // m
c.size(1), // n
a.size(1), // k
b_gptq_qzeros.size(0), // group number
use_exllama
);
return c;
}
void gptq_shuffle
(
torch::Tensor q_weight,
torch::Tensor q_perm
)
{
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
vllm::gptq::shuffle_exllama_weight(
(uint32_t*) q_weight.data_ptr(),
q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
q_weight.size(0) * 8,
q_weight.size(1)
);
}

View File

@ -0,0 +1,235 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_4_cuh
#define _qdq_4_cuh
#include "qdq_util.cuh"
namespace vllm {
namespace gptq {
// Permutation:
//
// 77775555 33331111 66664444 22220000
__forceinline__ __device__ void shuffle_4bit_8
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 4; i++)
{
uint32_t qa0 = qa & 0x0f;
uint32_t qa1 = (qa & 0xf0) >> 4;
qa >>= 8;
qb |= (qa1 << (i * 4 + 16));
qb |= (qa0 << (i * 4));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half2 y16 = __halves2half2(y16_, y16_);
const half z1_ = __float2half_rn(-1024.0f - 8.0f);
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z16 = __halves2half2(z16_, z16_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y16, z16);
dq[2] = __hadd2(q2.as_half2, z1);
dq[3] = __hfma2(q3.as_half2, y16, z16);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
const uint32_t zero,
const half scale,
half2 (&z1z16)[2],
half2 (&y1y16)[2]
)
{
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
half2 scale2 = __half2half2(scale);
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
z1z16[1] = __hmul2(scale2, __half2half2(z16));
const half y1 = __float2half_rn(1.0f);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __hmul2(scale2, __half2half2(y1));
y1y16[1] = __hmul2(scale2, __half2half2(y16));
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
const uint32_t zero,
half2(&z1z16)[2],
half2(&y1y16)[2]
)
{
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
z1z16[0] = __half2half2(z1.as_half);
z1z16[1] = __half2half2(z16);
const half y1 = __float2half_rn(1.0f);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __half2half2(y1);
y1y16[1] = __half2half2(y16);
}
__forceinline__ __device__ void dequant_4bit_8_gptq
(
const uint32_t q_0,
half2 (&dq)[4],
half2 (&z1z16)[2],
half2 (&y1y16)[2],
int stride,
bool scaled
)
{
const uint32_t c0 = 0x64006400;
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
if (scaled)
{
dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
}
else
{
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
}
}
} // namespace gptq
} // namespace vllm
#else
namespace vllm {
namespace gptq {
__forceinline__ __device__ void shuffle_4bit_8
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride
)
{
half dqh[8];
for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
const uint32_t zero,
const half scale,
half2 (&z1)[2],
half2 (&y1)[2]
)
{
half z = __int2half_rn(-((int)zero));
z = __hmul(z, scale);
z1[0] = __half2half2(z);
y1[0] = __half2half2(scale);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
const uint32_t zero,
half2(&z1)[2],
half2(&y1)[2]
)
{
half z = __int2half_rn(-((int)zero));
z1[0] = __half2half2(z);
}
__forceinline__ __device__ void dequant_4bit_8_gptq
(
const uint32_t q_0,
half2 (&dq)[4],
half2 (&z1)[2],
half2 (&y1)[2],
int stride,
bool scaled
)
{
half2 dqh2[8];
uint32_t qa = q_0;
for (int i = 0; i < 4; i++)
{
half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
dqh2[i] = __halves2half2(d0, d1);
}
if (scaled)
{
dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
}
else
{
dq[0] = __hadd2(dqh2[0], z1[0]);
dq[1] = __hadd2(dqh2[1], z1[0]);
dq[2] = __hadd2(dqh2[2], z1[0]);
dq[3] = __hadd2(dqh2[3], z1[0]);
}
}
} // namespace gptq
} // namespace vllm
#endif

View File

@ -0,0 +1,60 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_util_cuh
#define _qdq_util_cuh
namespace vllm {
namespace gptq {
union half2_uint32
{
uint32_t as_uint32;
half2 as_half2;
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
__device__ half2_uint32(half2 val) : as_half2(val) {}
};
union half_uint16
{
uint16_t as_uint16;
half as_half;
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
__device__ half_uint16(half val) : as_half(val) {}
};
// Max_scale premultiplied by 1/256
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
{
int qs_i = qs + 1;
half qs_h = __int2half_rn(qs_i * qs_i);
qs_h = __hmul(qs_h, max_scale);
return qs_h;
}
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
{
return __hmul(__int2half_rn(q - qzero), scale);
}
__forceinline__ __device__ half dq_ns(const int q, const int qzero)
{
//return __hsub(__int2half_rn(q), __int2half_rn(qzero));
return __int2half_rn(q - qzero);
}
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
{
return (int)((q >> shift) & mask);
}
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
{
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}
} // namespace gptq
} // namespace vllm
#endif

View File

@ -219,6 +219,7 @@ vllm_extension_sources = [
"csrc/activation_kernels.cu", "csrc/activation_kernels.cu",
"csrc/layernorm_kernels.cu", "csrc/layernorm_kernels.cu",
"csrc/quantization/squeezellm/quant_cuda_kernel.cu", "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
"csrc/quantization/gptq/q_gemm.cu",
"csrc/cuda_utils_kernels.cu", "csrc/cuda_utils_kernels.cu",
"csrc/pybind.cpp", "csrc/pybind.cpp",
] ]

View File

@ -142,7 +142,7 @@ class ModelConfig:
self.tokenizer_mode = tokenizer_mode self.tokenizer_mode = tokenizer_mode
def _verify_quantization(self) -> None: def _verify_quantization(self) -> None:
supported_quantization = ["awq", "squeezellm"] supported_quantization = ["awq", "gptq", "squeezellm"]
rocm_not_supported_quantization = ["awq"] rocm_not_supported_quantization = ["awq"]
if self.quantization is not None: if self.quantization is not None:
self.quantization = self.quantization.lower() self.quantization = self.quantization.lower()

View File

@ -179,7 +179,7 @@ class EngineArgs:
parser.add_argument('--quantization', parser.add_argument('--quantization',
'-q', '-q',
type=str, type=str,
choices=['awq', 'squeezellm', None], choices=['awq', 'gptq', 'squeezellm', None],
default=None, default=None,
help='Method used to quantize the weights') help='Method used to quantize the weights')
return parser return parser

View File

@ -38,8 +38,9 @@ class LLM:
However, if the `torch_dtype` in the config is `float32`, we will However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead. use `float16` instead.
quantization: The method used to quantize the model weights. Currently, quantization: The method used to quantize the model weights. Currently,
we support "awq". If None, we assume the model weights are not we support "awq", "gptq" and "squeezellm". If None, we assume the
quantized and use `dtype` to determine the data type of the weights. model weights are not quantized and use `dtype` to determine the
data type of the weights.
revision: The specific model version to use. It can be a branch name, revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id. a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a tokenizer_revision: The specific tokenizer version to use. It can be a

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional from typing import Any, Dict, List, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -21,8 +21,10 @@ class LinearMethodBase(ABC):
"""Base class for different (maybe quantized) linear methods.""" """Base class for different (maybe quantized) linear methods."""
@abstractmethod @abstractmethod
def create_weights(self, input_size: int, output_size: int, def create_weights(self, input_size_per_partition: int,
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: output_size_per_partition: int, input_size: int,
output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
"""Create weights for a linear layer.""" """Create weights for a linear layer."""
raise NotImplementedError raise NotImplementedError
@ -46,10 +48,12 @@ class UnquantizedLinearMethod(LinearMethodBase):
def __init__(self, separate_bias_add: bool = False): def __init__(self, separate_bias_add: bool = False):
self.separate_bias_add = separate_bias_add self.separate_bias_add = separate_bias_add
def create_weights(self, input_size: int, output_size: int, def create_weights(self, input_size_per_partition: int,
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: output_size_per_partition: int, input_size: int,
weight = Parameter(torch.empty(output_size, output_size: int,
input_size, params_dtype: torch.dtype) -> Dict[str, Any]:
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
dtype=params_dtype), dtype=params_dtype),
requires_grad=False) requires_grad=False)
@ -102,9 +106,11 @@ class ReplicatedLinear(torch.nn.Module):
linear_method = UnquantizedLinearMethod() linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method self.linear_method = linear_method
self.linear_weights = self.linear_method.create_weights( self.linear_weights = self.linear_method.create_weights(
self.input_size, self.output_size, self.params_dtype) self.input_size, self.output_size, self.input_size,
self.output_size, self.params_dtype)
for name, weight in self.linear_weights.items(): for name, weight in self.linear_weights.items():
self.register_parameter(name, weight) if isinstance(weight, torch.Tensor):
self.register_parameter(name, weight)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.output_size, torch.empty(self.output_size,
@ -168,10 +174,12 @@ class ColumnParallelLinear(torch.nn.Module):
linear_method = UnquantizedLinearMethod() linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method self.linear_method = linear_method
self.linear_weights = self.linear_method.create_weights( self.linear_weights = self.linear_method.create_weights(
self.input_size, self.output_size_per_partition, self.params_dtype) self.input_size, self.output_size_per_partition, self.input_size,
self.output_size, self.params_dtype)
for name, weight in self.linear_weights.items(): for name, weight in self.linear_weights.items():
self.register_parameter(name, weight) if isinstance(weight, torch.Tensor):
set_weight_attrs(weight, {"weight_loader": self.weight_loader}) self.register_parameter(name, weight)
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.output_size_per_partition, torch.empty(self.output_size_per_partition,
@ -295,10 +303,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size) shard_size)
else: else:
logger.warning( ignore_warning = getattr(param, "ignore_warning", False)
"Loading a weight without `output_dim` attribute in " if not ignore_warning:
"MergedColumnParallelLinear, assume the weight is " logger.warning(
"the same for all partitions.") "Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
@ -418,10 +428,12 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size) shard_size)
else: else:
logger.warning( ignore_warning = getattr(param, "ignore_warning", False)
"Loading a weight without `output_dim` attribute in " if not ignore_warning:
"QKVParallelLinear, assume the weight is the same " logger.warning(
"for all partitions.") "Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"for all partitions.")
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
@ -481,10 +493,12 @@ class RowParallelLinear(torch.nn.Module):
linear_method = UnquantizedLinearMethod() linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method self.linear_method = linear_method
self.linear_weights = self.linear_method.create_weights( self.linear_weights = self.linear_method.create_weights(
self.input_size_per_partition, self.output_size, self.params_dtype) self.input_size_per_partition, self.output_size, self.input_size,
self.output_size, self.params_dtype)
for name, weight in self.linear_weights.items(): for name, weight in self.linear_weights.items():
self.register_parameter(name, weight) if isinstance(weight, torch.Tensor):
set_weight_attrs(weight, {"weight_loader": self.weight_loader}) self.register_parameter(name, weight)
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
if not reduce_results and (bias and not skip_bias_add): if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the " raise ValueError("When not reduce the results, adding bias to the "

View File

@ -1,11 +1,13 @@
from typing import Type from typing import Type
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
_QUANTIZATION_CONFIG_REGISTRY = { _QUANTIZATION_CONFIG_REGISTRY = {
"awq": AWQConfig, "awq": AWQConfig,
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig, "squeezellm": SqueezeLLMConfig,
} }

View File

@ -77,14 +77,16 @@ class AWQLinearMethod(LinearMethodBase):
def __init__(self, quant_config: AWQConfig): def __init__(self, quant_config: AWQConfig):
self.quant_config = quant_config self.quant_config = quant_config
def create_weights(self, input_size: int, output_size: int, def create_weights(self, input_size_per_partition: int,
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: output_size_per_partition: int, input_size: int,
if input_size % self.quant_config.group_size != 0: output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError( raise ValueError(
"The input size is not aligned with the quantized " "The input size is not aligned with the quantized "
"weight shape. This can be caused by too large " "weight shape. This can be caused by too large "
"tensor parallel size.") "tensor parallel size.")
if output_size % self.quant_config.pack_factor != 0: if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError( raise ValueError(
"The output size is not aligned with the quantized " "The output size is not aligned with the quantized "
"weight shape. This can be caused by too large " "weight shape. This can be caused by too large "
@ -92,8 +94,8 @@ class AWQLinearMethod(LinearMethodBase):
qweight = Parameter( qweight = Parameter(
torch.empty( torch.empty(
input_size, input_size_per_partition,
output_size // self.quant_config.pack_factor, output_size_per_partition // self.quant_config.pack_factor,
device="cuda", device="cuda",
dtype=torch.int32, dtype=torch.int32,
), ),
@ -108,8 +110,8 @@ class AWQLinearMethod(LinearMethodBase):
}) })
qzeros = Parameter( qzeros = Parameter(
torch.empty( torch.empty(
input_size // self.quant_config.group_size, input_size_per_partition // self.quant_config.group_size,
output_size // self.quant_config.pack_factor, output_size_per_partition // self.quant_config.pack_factor,
device="cuda", device="cuda",
dtype=torch.int32, dtype=torch.int32,
), ),
@ -124,8 +126,8 @@ class AWQLinearMethod(LinearMethodBase):
}) })
scales = Parameter( scales = Parameter(
torch.empty( torch.empty(
input_size // self.quant_config.group_size, input_size_per_partition // self.quant_config.group_size,
output_size, output_size_per_partition,
device="cuda", device="cuda",
dtype=params_dtype, dtype=params_dtype,
), ),
@ -142,7 +144,7 @@ class AWQLinearMethod(LinearMethodBase):
} }
def apply_weights(self, def apply_weights(self,
weights: Dict[str, torch.Tensor], weights: Dict[str, Any],
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = weights["qweight"] qweight = weights["qweight"]

View File

@ -0,0 +1,215 @@
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
class GPTQConfig(QuantizationConfig):
"""Config class for GPTQ.
Reference: https://arxiv.org/abs/2210.17323
"""
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.pack_factor = 32 // self.weight_bits
# exllama kernel v1 only supports 4 bit
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"GPTQ, but got {self.weight_bits} bits.")
def __repr__(self) -> str:
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act})")
@classmethod
def get_name(cls) -> str:
return "gptq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 60
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
return cls(weight_bits, group_size, desc_act)
def get_linear_method(self) -> "GPTQLinearMethod":
return GPTQLinearMethod(self)
def get_scaled_act_names(self) -> List[str]:
return []
class ExllamaState(Enum):
UNUSED = enum.auto()
UNINITIALIZED = enum.auto()
READY = enum.auto()
class GPTQLinearMethod(LinearMethodBase):
"""Linear method for GPTQ.
Args:
quant_config: The GPTQ quantization config.
"""
def __init__(self, quant_config: GPTQConfig):
self.quant_config = quant_config
def create_weights(
self,
input_size_per_partition: int,
output_size_per_partition: int,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
del output_size # Unused.
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
exllama_state = ExllamaState.UNINITIALIZED
scale_and_zero_size = input_size // group_size
scale_and_zero_input_dim = None
if input_size != input_size_per_partition and self.quant_config.group_size != -1:
# For act-order models, we cannot use Exllama for row parallel layer
if self.quant_config.desc_act:
exllama_state = ExllamaState.UNUSED
else:
# we need to partition qzeros and scales for exllama kernel
scale_and_zero_size = input_size_per_partition // group_size
scale_and_zero_input_dim = 0
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.quant_config.pack_factor,
})
g_idx = Parameter(
torch.tensor(
[
i // self.quant_config.group_size
for i in range(input_size_per_partition)
],
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True})
qzeros = Parameter(
torch.empty(
scale_and_zero_size,
output_size_per_partition // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qzeros, {
"input_dim": scale_and_zero_input_dim,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
})
scales = Parameter(
torch.empty(
scale_and_zero_size,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(scales, {
"input_dim": scale_and_zero_input_dim,
"output_dim": 1,
})
return {
"qweight": qweight,
"g_idx": g_idx,
"qzeros": qzeros,
"scales": scales,
"exllama_state": exllama_state,
}
def apply_weights(self,
weights: Dict[str, Any],
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = weights["qweight"]
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
if self.quant_config.desc_act:
weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
torch.int)
else:
weights["g_idx"] = torch.empty((1, 1), device="meta")
weights["exllama_state"] = ExllamaState.READY
ops.gptq_shuffle(weights["qweight"], weights["g_idx"])
output = ops.gptq_gemm(reshaped_x, weights["qweight"],
weights["qzeros"], weights["scales"],
weights["g_idx"],
weights["exllama_state"] == ExllamaState.READY)
if bias is not None:
output = output + bias
return output.reshape(out_shape)

View File

@ -67,17 +67,19 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
def __init__(self, quant_config: SqueezeLLMConfig): def __init__(self, quant_config: SqueezeLLMConfig):
self.quant_config = quant_config self.quant_config = quant_config
def create_weights(self, input_size: int, output_size: int, def create_weights(self, input_size_per_partition: int,
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: output_size_per_partition: int, input_size: int,
if input_size % self.quant_config.pack_factor != 0: output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
if input_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError( raise ValueError(
"The input size is not aligned with the quantized " "The input size is not aligned with the quantized "
"weight shape. This can be caused by too large " "weight shape. This can be caused by too large "
"tensor parallel size.") "tensor parallel size.")
qweight = Parameter( qweight = Parameter(
torch.empty( torch.empty(
input_size // self.quant_config.pack_factor, input_size_per_partition // self.quant_config.pack_factor,
output_size, output_size_per_partition,
device="cuda", device="cuda",
dtype=torch.int32, dtype=torch.int32,
), ),
@ -108,7 +110,7 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
} }
def apply_weights(self, def apply_weights(self,
weights: Dict[str, torch.Tensor], weights: Dict[str, Any],
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = weights["qweight"] qweight = weights["qweight"]

View File

@ -332,11 +332,18 @@ class AquilaForCausalLM(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
param = params_dict[name.replace(weight_name, param_name)] name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)

View File

@ -355,11 +355,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
param = params_dict[name.replace(weight_name, param_name)] name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)

View File

@ -377,6 +377,9 @@ class ChatGLMForCausalLM(nn.Module):
continue continue
if "word_embeddings" in name: if "word_embeddings" in name:
name = name.replace(".word_embeddings", "") name = name.replace(".word_embeddings", "")
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)

View File

@ -425,27 +425,32 @@ class FalconForCausalLM(nn.Module):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
if "query_key_value" in name: if "query_key_value" in name:
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
loaded_weight_shape = loaded_weight.shape loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view( if output_dim is not None:
loaded_weight_shape[:output_dim] + loaded_weight = loaded_weight.view(
(total_num_kv_heads, num_query_heads_per_kv_head + 2, -1) + loaded_weight_shape[:output_dim] +
loaded_weight_shape[output_dim + 1:]) (total_num_kv_heads, num_query_heads_per_kv_head + 2,
wq = loaded_weight.narrow( -1) + loaded_weight_shape[output_dim + 1:])
output_dim + 1, 0, num_query_heads_per_kv_head).reshape( wq = loaded_weight.narrow(
*loaded_weight_shape[:output_dim], -1, output_dim + 1, 0,
*loaded_weight_shape[output_dim + 1:]) num_query_heads_per_kv_head).reshape(
wk = loaded_weight.narrow( *loaded_weight_shape[:output_dim], -1,
output_dim + 1, num_query_heads_per_kv_head, *loaded_weight_shape[output_dim + 1:])
1).reshape(*loaded_weight_shape[:output_dim], -1, wk = loaded_weight.narrow(
*loaded_weight_shape[output_dim + 1:]) output_dim + 1, num_query_heads_per_kv_head,
wv = loaded_weight.narrow( 1).reshape(*loaded_weight_shape[:output_dim], -1,
output_dim + 1, num_query_heads_per_kv_head + 1, *loaded_weight_shape[output_dim + 1:])
1).reshape(*loaded_weight_shape[:output_dim], -1, wv = loaded_weight.narrow(
*loaded_weight_shape[output_dim + 1:]) output_dim + 1, num_query_heads_per_kv_head + 1,
loaded_weight = torch.cat([wq, wk, wv], dim=output_dim) 1).reshape(*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)

View File

@ -275,7 +275,6 @@ class GPT2LMHeadModel(nn.Module):
if not name.endswith(".weight"): if not name.endswith(".weight"):
continue continue
loaded_weight = loaded_weight.t() loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)

View File

@ -274,11 +274,18 @@ class GPTJForCausalLM(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
param = params_dict[name.replace(weight_name, param_name)] name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)

View File

@ -72,7 +72,6 @@ class GPTNeoXAttention(nn.Module):
config.hidden_size, config.hidden_size,
linear_method=linear_method, linear_method=linear_method,
) )
scaling = self.head_size**-0.5 scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct) rotary_dim = int(self.head_size * config.rotary_pct)
assert rotary_dim % 2 == 0 assert rotary_dim % 2 == 0

View File

@ -289,11 +289,18 @@ class InternLMForCausalLM(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
param = params_dict[name.replace(weight_name, param_name)] name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)

View File

@ -330,11 +330,18 @@ class LlamaForCausalLM(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
param = params_dict[name.replace(weight_name, param_name)] name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)

View File

@ -321,11 +321,18 @@ class MistralForCausalLM(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
param = params_dict[name.replace(weight_name, param_name)] name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)

View File

@ -153,7 +153,7 @@ class MixtralMoE(nn.Module):
self.gate = ReplicatedLinear(config.hidden_size, self.gate = ReplicatedLinear(config.hidden_size,
self.num_total_experts, self.num_total_experts,
bias=False, bias=False,
linear_method=linear_method) linear_method=None)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape batch_size, sequence_length, hidden_dim = hidden_states.shape
@ -418,11 +418,18 @@ class MixtralForCausalLM(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
param = params_dict[name.replace(weight_name, param_name)] name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)

View File

@ -297,6 +297,9 @@ class MPTForCausalLM(nn.Module):
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)

View File

@ -345,11 +345,18 @@ class OPTForCausalLM(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
param = params_dict[name.replace(weight_name, param_name)] name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)

View File

@ -305,6 +305,9 @@ class PhiForCausalLM(nn.Module):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# pylint: disable=E1136 # pylint: disable=E1136
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",

View File

@ -82,7 +82,6 @@ class QWenAttention(nn.Module):
self.num_heads = (self.total_num_heads // self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size) tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads self.head_dim = hidden_size // self.total_num_heads
self.c_attn = QKVParallelLinear( self.c_attn = QKVParallelLinear(
hidden_size, hidden_size,
self.head_dim, self.head_dim,
@ -279,11 +278,18 @@ class QWenLMHeadModel(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
param = params_dict[name.replace(weight_name, param_name)] name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)

View File

@ -320,11 +320,18 @@ class YiForCausalLM(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
param = params_dict[name.replace(weight_name, param_name)] name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)

View File

@ -287,4 +287,5 @@ def initialize_dummy_weights(
values between -1e-3 and 1e-3 works well for most models. values between -1e-3 and 1e-3 works well for most models.
""" """
for param in model.state_dict().values(): for param in model.state_dict().values():
param.data.uniform_(low, high) if torch.is_floating_point(param):
param.data.uniform_(low, high)