Add GPTQ support (#916)

parent: c06170cc8e
commit: 0fbfc4b81b
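For context, a minimal usage sketch of the feature this commit adds, through vLLM's Python entry point (the checkpoint path is hypothetical; any locally available GPTQ-quantized 4-bit model is assumed):

    # Hypothetical local path to a GPTQ-quantized checkpoint.
    from vllm import LLM, SamplingParams

    llm = LLM(model="./llama-7b-gptq", quantization="gptq")
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)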
@@ -84,7 +84,7 @@ if __name__ == '__main__':
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'squeezellm', None],
+                        choices=['awq', 'gptq', 'squeezellm', None],
                         default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
@@ -244,7 +244,7 @@ if __name__ == "__main__":
     parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'squeezellm', None],
+                        choices=['awq', 'gptq', 'squeezellm', None],
                         default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n",
csrc/ops.h (12 lines changed)
@@ -77,3 +77,15 @@ void squeezellm_gemm(
   torch::Tensor mat,
   torch::Tensor mul,
   torch::Tensor lookup_table);
+
+torch::Tensor gptq_gemm(
+  torch::Tensor a,
+  torch::Tensor b_q_weight,
+  torch::Tensor b_gptq_qzeros,
+  torch::Tensor b_gptq_scales,
+  torch::Tensor b_g_idx,
+  bool use_exllama);
+
+void gptq_shuffle(
+  torch::Tensor q_weight,
+  torch::Tensor q_perm);
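The new gptq_gemm/gptq_shuffle declarations above operate on packed 4-bit tensors. As a rough reference for the layout they assume, here is a small, self-contained PyTorch sketch of the dequantize-then-matmul that gptq_gemm performs on the GPU (shapes, float32 math, and the helper name are illustrative assumptions; the +1 zero-point offset mirrors the kernels added in this commit):

    # CPU/float32 reference for 4-bit GPTQ GEMM: unpack int4 weights, apply
    # per-group scales and zero points (stored minus one, hence the +1), then matmul.
    import torch

    def gptq_gemm_reference(a, qweight, qzeros, scales, g_idx):
        # qweight: [k // 8, n] int32, eight 4-bit values packed per int32 along dim 0
        # qzeros:  [groups, n // 8] int32, packed the same way along dim 1
        # scales:  [groups, n] float;  g_idx: [k] long, input channel -> group index
        shifts = torch.arange(0, 32, 4)
        w = ((qweight.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF).reshape(-1, qweight.shape[1])
        z = ((qzeros.unsqueeze(-1) >> shifts.view(1, 1, -1)) & 0xF).reshape(qzeros.shape[0], -1) + 1
        b = (w - z[g_idx]).float() * scales[g_idx]   # dequantized weight, [k, n]
        return a @ b

    k, n, groups = 64, 32, 4
    a = torch.randn(2, k)
    qweight = torch.randint(0, 2**31 - 1, (k // 8, n), dtype=torch.int32)
    qzeros = torch.randint(0, 2**31 - 1, (groups, n // 8), dtype=torch.int32)
    scales = torch.rand(groups, n)
    g_idx = torch.arange(k) // (k // groups)
    print(gptq_gemm_reference(a, qweight, qzeros, scales, g_idx).shape)  # torch.Size([2, 32])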
@@ -52,8 +52,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // Quantization ops
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
 #endif
+  ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
+  ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");

   // Cache ops
csrc/quantization/gptq/compat.cuh (new file, 64 lines)
@@ -0,0 +1,64 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _compat_cuh
#define _compat_cuh

namespace vllm {
namespace gptq {
// atomicAdd for half types, to support CC < 7.x

__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
    unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;

    do
    {
        assumed = old;
        __half_raw hsum;
        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
        half tmpres = __hadd(hsum, val);
        hsum = __half_raw(tmpres);
        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
        old = atomicCAS(address_as_ui, assumed, old);
    }
    while (assumed != old);
}

// atomicAdd for half2 types

__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
    unsigned int* address_as_ui = (unsigned int*)address;
    unsigned int old = *address_as_ui;
    unsigned int assumed;
    do
    {
        assumed = old;
        half2 old_val = *((half2*)&old);
        half2 new_val = __hadd2(old_val, val);
        old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
    }
    while (assumed != old);
}

//

#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }

#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif

#endif
#endif

}  // namespace gptq
}  // namespace vllm
#endif
csrc/quantization/gptq/matrix_view.cuh (new file, 151 lines)
@@ -0,0 +1,151 @@
/*
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
*/

#ifndef _matrix_view_cuh
#define _matrix_view_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include "qdq_util.cuh"

namespace vllm {
namespace gptq {

class MatrixView_half
{
public:
    const half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
    __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
    __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }

    __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
    {
        half2* ptr = (half2*) item_ptr(row, column);
        half2 i01 = ptr[0];
        half2 i23 = ptr[1];
        items[0] = __low2half(i01);
        items[1] = __high2half(i01);
        items[2] = __low2half(i23);
        items[3] = __high2half(i23);
    }
    __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
    {
        half2* ptr = (half2*)item_ptr(row, column);
        half2 i01 = ptr[0];
        half2 i23 = ptr[1];
        items[0] = __half2float(__low2half(i01));
        items[1] = __half2float(__high2half(i01));
        items[2] = __half2float(__low2half(i23));
        items[3] = __half2float(__high2half(i23));
    }

    __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
    {
        half2* ptr = (half2*)item_ptr(row, column);
        half2 i01 = ptr[0];
        half2 i23 = ptr[1];
        items[0] = __half2half2(__low2half(i01));
        items[1] = __half2half2(__high2half(i01));
        items[2] = __half2half2(__low2half(i23));
        items[3] = __half2half2(__high2half(i23));
    }
};

class MatrixView_half_rw
{
public:
    half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
    __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
    __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
    __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }

    __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
    {
        half2 v01 = __halves2half2(v0, v1);
        half2 v23 = __halves2half2(v2, v3);
        half2* ptr = (half2*) item_ptr(row, column);
        ptr[0] = v01;
        ptr[1] = v23;
    }
};

class MatrixView_q4_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (column & 0x07) * 4;
        return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
    }

    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
    {
        int shift = (column & 0x07) * 4;
        uint32_t d = data[row * width / 8 + column / 8] >> shift;
        items[0] = d & 0x0f;
        items[1] = (d >> 4) & 0x0f;
    }

    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
    {
        int shift = (column & 0x07) * 4;
        uint32_t d = data[row * width / 8 + column / 8] >> shift;
        items[0] = d & 0x0f;
        items[1] = (d >> 4) & 0x0f;
        items[2] = (d >> 8) & 0x0f;
        items[3] = (d >> 12) & 0x0f;
    }
};

class MatrixView_q4_column
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (row & 0x07) * 4;
        return (data[row / 8 * width + column] >> shift) & 0x0f;
    }

    __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
    __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};

}  // namespace gptq
}  // namespace vllm
#endif
csrc/quantization/gptq/q_gemm.cu (new file, 859 lines)
@@ -0,0 +1,859 @@
/*
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopqwop200/GPTQ-for-LLaMa
*/

#include <cstdint>
#include <cstdio>

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include "compat.cuh"
#include "matrix_view.cuh"
#include "qdq_4.cuh"

namespace vllm {
namespace gptq {

#define BLOCK_KN_SIZE 128
#define BLOCK_M_SIZE_MAX 8
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
#define MAX_Q_GEMM_ROWS 50
#define MAX_ALT_GEMM_ROWS 8
#define THREADS_X 32
#define THREADS_Y 32
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))

#if defined(USE_ROCM)
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
                                                               hipblasOperation_t transA,
                                                               hipblasOperation_t transB,
                                                               int m,
                                                               int n,
                                                               int k,
                                                               const half* alpha,
                                                               const half* AP,
                                                               int lda,
                                                               const half* BP,
                                                               int ldb,
                                                               const half* beta,
                                                               half* CP,
                                                               int ldc) {
    return hipblasHgemm(handle, transA, transB, m, n, k,
                        reinterpret_cast<const hipblasHalf *>(alpha),
                        reinterpret_cast<const hipblasHalf *>(AP), lda,
                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
                        reinterpret_cast<const hipblasHalf *>(beta),
                        reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define hipblasHgemm __compat_hipblasHgemm

// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_hgemm __compat_hipblasHgemm
#endif

__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
{
    half2 result = {};
    const half2* a2_ptr = (const half2*)a_ptr;
    #pragma unroll
    for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
    return __hadd2(result, g_result);
}

__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
{
    half2 result = {};
    const half2* a2_ptr = (const half2*)a_ptr;
    #pragma unroll
    for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
    return __half2float(__low2half(result)) + __half2float(__high2half(result));
}

typedef void (*fp_gemm_half_q_half_gptq_kernel)
(
    const half*,
    const uint32_t*,
    const uint32_t*,
    const half*,
    half*,
    const int,
    const int,
    const int,
    const int,
    const int*
);

template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_kernel
(
    const half* __restrict__ a,
    const uint32_t* __restrict__ b_q_weight,
    const uint32_t* __restrict__ b_gptq_qzeros,
    const half* __restrict__ b_gptq_scales,
    half* __restrict__ c,
    const int size_m,
    const int size_n,
    const int size_k,
    const int groups,
    const int* __restrict__ b_q_perm
)
{
    MatrixView_half a_(a, size_m, size_k);
    MatrixView_half_rw c_(c, size_m, size_n);
    MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

    int t = threadIdx.x;

    // Block
    int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
    int offset_m = blockIdx.y * m_count;
    int offset_k = blockIdx.z * BLOCK_KN_SIZE;

    int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
    int end_m = min(offset_m + m_count, size_m);
    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

    int n = offset_n + t * 4;

    // Preload block_a
    __shared__ half block_a[m_count][BLOCK_KN_SIZE];

    if (offset_k + t < end_k)
    {
        for (int m = 0; m < m_count; ++m)
        {
            const half* a_ptr = a_.item_ptr(offset_m + m, 0);
            half* block_a_ptr = block_a[m];

            half a0;
            if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
            else a0 = a_ptr[offset_k + t];
            block_a_ptr[t] = a0;
        }
    }

    // Zero output
    if (n >= size_n) return;

    if (blockIdx.z == 0)
    {
        for (int m = 0; m < m_count; m++)
            *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
    }

    __syncthreads();

    // Find initial group
    int groupsize = size_k / groups;
    int group = offset_k / groupsize;
    int nextgroup = offset_k + groupsize;

    // a, b offset
    int qk = offset_k / (32 / 4);

    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
    const half* a_ptr = &block_a[0][0];
    int a_stride = BLOCK_KN_SIZE;

    // Initial group
    int zeros[4];
    float scales[4];
    half2 z1z16[4][2];
    half2 y1y16[4][2];
    b_gptq_qzeros_.item4(zeros, group, n);
    b_gptq_scales_.item4_f(scales, group, n);
    dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
    dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
    dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
    dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);

    // Column result
    float block_c[m_count][4] = {};

    // Dequantize and multiply
    int k = offset_k;
    while (k < end_k)
    {
        if (k == nextgroup)
        {
            group++;
            nextgroup += groupsize;
            b_gptq_qzeros_.item4(zeros, group, n);
            b_gptq_scales_.item4_f(scales, group, n);
            dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
            dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
            dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
            dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
        }

        #pragma unroll
        for (int j = 0; j < 4; j++)
        {
            const int4* b_ptr4 = (int4*) b_ptr;
            int4 load_int4 = *b_ptr4;

            half2 dq[4][4];
            dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
            dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
            dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
            dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);

            #pragma unroll
            for (int m = 0; m < m_count; m++)
            {
                block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
                block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
                block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
                block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
            }

            b_ptr += size_n;
            a_ptr += 8;
        }

        k += 32;
    }

    for (int m = 0; m < m_count; m++)
    {
        half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
        half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
        half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
        atomicAdd(out    , result01);
        atomicAdd(out + 1, result23);
    }
}


fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
{
    #if BLOCK_M_SIZE_MAX >= 1
    if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 2
    if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 3
    if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 4
    if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 5
    if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 6
    if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 7
    if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 8
    if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
    #endif
    return NULL;
}


void gemm_half_q_half_cuda_part
(
    const half* a,
    const uint32_t* b_q_weight,
    const uint32_t* b_gptq_qzeros,
    const half* b_gptq_scales,
    const int* b_q_perm,
    half* c,
    int size_m,
    int size_n,
    int size_k,
    int m_count,
    int groups
)
{
    dim3 blockDim, gridDim;
    blockDim.x = BLOCK_KN_SIZE;
    blockDim.y = 1;
    blockDim.z = 1;
    gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
    gridDim.y = DIVIDE(size_m, m_count);
    gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);

    fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);

    kernel<<<gridDim, blockDim>>>
    (
        a,
        b_q_weight,
        b_gptq_qzeros,
        b_gptq_scales,
        c,
        size_m,
        size_n,
        size_k,
        groups,
        b_q_perm
    );
}


__global__ void reconstruct_exllama_kernel
(
    const uint32_t* __restrict__ b_q_weight,
    const int* __restrict__ b_q_perm,
    const uint32_t* __restrict__ b_gptq_qzeros,
    const half* __restrict__ b_gptq_scales,
    const int size_k,
    const int size_n,
    const int groups,
    half* __restrict__ b
)
{
    MatrixView_half_rw b_(b, size_k, size_n);
    MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

    int offset_k = BLOCK_KN_SIZE * blockIdx.y;
    int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;

    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

    // Preload remapping table
    __shared__ int perm[BLOCK_KN_SIZE];
    int t = threadIdx.x;

    if (b_q_perm)
    {
        if (offset_k + t < size_k)
            perm[t] = b_q_perm[offset_k + t];
    }

    // Column
    int n = offset_n + t * 4;
    if (n >= size_n) return;

    // Find initial group
    int groupsize = size_k / groups;
    int group = offset_k / groupsize;
    int nextgroup = offset_k + groupsize;

    // b offset
    int qk = offset_k / (32 / 4);

    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;

    // Initial zeros/scale
    int zeros[4];
    half2 scales[4];
    half2 z1z16[4][2];
    half2 y1y16[4][2];
    b_gptq_qzeros_.item4(zeros, group, n);
    b_gptq_scales_.item4_h2(scales, group, n);
    dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
    dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
    dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
    dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);

    __syncthreads();

    int k = offset_k;
    int lk = 0;

    while (k < end_k)
    {
        if (k == nextgroup)
        {
            group++;
            nextgroup += groupsize;
            b_gptq_qzeros_.item4(zeros, group, n);
            b_gptq_scales_.item4_h2(scales, group, n);
            dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
            dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
            dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
            dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
        }

        for (int p = 0; p < 4; p++)
        {
            half2 dq[4][4];
            const int4* b_ptr4 = (int4*) b_ptr;
            int4 load_int4 = *b_ptr4;

            dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
            dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
            dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
            dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);

            b_ptr += size_n;
            //half* dqh = (half*)dq;
            if (b_q_perm)
            {
                for (int j = 0; j < 4; j++)
                {
                    for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
                    b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
                    b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
                }
            }
            else
            {
                for (int j = 0; j < 4; j++)
                {
                    for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
                    b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
                    b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
                }
            }
        }
        k += 32;
    }
}


void reconstruct_exllama
(
    const uint32_t* b_q_weight,
    const uint32_t* b_gptq_qzeros,
    const half* b_gptq_scales,
    const int* b_q_perm,
    half* out,
    int height,
    int width,
    int groups
)
{
    dim3 blockDim, gridDim;
    blockDim.x = BLOCK_KN_SIZE;
    blockDim.y = 1;
    gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
    gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);

    reconstruct_exllama_kernel<<<gridDim, blockDim>>>
    (
        b_q_weight,
        b_q_perm,
        b_gptq_qzeros,
        b_gptq_scales,
        height,
        width,
        groups,
        out
    );
}


__global__ void gemm_half_q_half_alt_kernel(
    const half2* __restrict__ vec,
    const uint32_t* __restrict__ mat,
    half* __restrict__ mul,
    const half* __restrict__ scales,
    const uint32_t* __restrict__ zeros,
    const int* __restrict__ g_idx,
    int batch,
    int height,
    int width
)
{
    int zero_width = width / 8;
    int vec_height = height * 4;
    const int blockwidth2 = BLOCK_KN_SIZE / 2;
    int b = blockIdx.y * BLOCK_M_SIZE_MAX;
    int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
    int h = BLOCK_KN_SIZE * blockIdx.z / 8;
    int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
    int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;

    __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
    if (threadIdx.x < h_end) {
        for (int m = 0; m < b_end; ++m) {
            blockvec[m][threadIdx.x] =
                vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 +
                    threadIdx.x];
        }
    }

    __shared__ half2 deq2[256][8];
    int val = threadIdx.x / 8;
    int off = threadIdx.x % 8;
    for (; val < 256; val += BLOCK_KN_SIZE / 8) {
        deq2[val][off] = __halves2half2(
            __int2half_rn(val & 0xF), __int2half_rn(val >> 4)
        );
    }

    if (blockIdx.z == 0)
    {
        for (int m = 0; m < b_end; m++)
            mul[(b + m) * width + w] = __int2half_rn(0);
    }
    __syncthreads();

    int i = width * h + w;
    int g_h = h * 8;
    int k = 0;
    int z_w = w / 8;
    int z_mod = (w % 8) * 4;
    half2 res2;
    half res[BLOCK_M_SIZE_MAX] = {};

    unsigned int tmp;
    while (k < h_end) {
        tmp = mat[i];
        half2 scales_tmp[4];
        half2 zeros_tmp[4];
        for (int tmp_k = 0; tmp_k < 4; tmp_k++) {
            int g = g_idx[g_h + (k + tmp_k) * 2];
            int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1];
            half scale_f = scales[g * width + w];
            half scale_f2 = scales[g2 * width + w];
            half2 scale = __halves2half2(scale_f, scale_f2);
            half2 zero = __halves2half2(
                __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)),
                __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))
            );
            scales_tmp[tmp_k] = scale;
            zeros_tmp[tmp_k] = zero;
        }
        for (int m = 0; m < b_end; m++) {
            res2 = {};
            res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
            res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
            res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2);
            res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2);
            res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
        }
        i += width;
        k += 4;
    }
    for (int m = 0; m < b_end; m++) {
        atomicAdd(&mul[(b + m) * width + w], res[m]);
    }
}


void gemm_half_q_half_alt
(
    const half* a,
    const uint32_t* b_q_weight,
    const uint32_t* b_gptq_qzeros,
    const half* b_gptq_scales,
    const int* b_g_idx,
    half* c,
    int size_m,
    int size_n,
    int size_k
)
{
    dim3 blockDim, gridDim;
    blockDim.x = BLOCK_KN_SIZE;
    blockDim.y = 1;
    blockDim.z = 1;
    gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE);
    gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
    gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);

    gemm_half_q_half_alt_kernel<<<gridDim, blockDim>>>
    (
        (const half2*) a,
        b_q_weight,
        c,
        b_gptq_scales,
        b_gptq_qzeros,
        b_g_idx,
        size_m,
        size_k / 8,
        size_n
    );
}


__global__ void reconstruct_gptq_kernel
(
    const uint32_t* __restrict__ w,
    const half* __restrict__ w_scales,
    const uint32_t* __restrict__ w_zeros,
    const int* __restrict__ g_idx,
    const int height,
    const int width,
    const int group,
    half* __restrict__ out
)
{
    // Start of block

    int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
    int row = blockIdx.y * 8;
    if (column >= width) return;

    // Views

    MatrixView_q4_column w_(w, height, width);
    MatrixView_half_rw out_(out, height, width);
    MatrixView_half w_scales_(w_scales, group, width);
    MatrixView_q4_row w_zeros_(w_zeros, group, width);

    uint32_t w_read = w_.item_uint32_t(row, column);
    half* out_ptr = out_.item_ptr(row, column);

    #pragma unroll
    for (int s = 0; s < 32; s += 4)
    {
        int group = g_idx[row + s / 4];
        half w_scale = w_scales_.item(group, column);
        uint32_t w_zero = w_zeros_.item(group, column) + 1;
        half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
        *out_ptr = w_item; out_ptr += out_.width;
    }
}


void reconstruct_gptq
(
    const uint32_t* b_q_weight,
    const uint32_t* b_gptq_qzeros,
    const half* b_gptq_scales,
    const int* b_g_idx,
    half* out,
    int height,
    int width,
    int groups
)
{
    dim3 blockDim, gridDim;
    blockDim.x = BLOCK_KN_SIZE;
    blockDim.y = 1;
    gridDim.y = DIVIDE(height, 8);
    gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
    reconstruct_gptq_kernel<<<gridDim, blockDim>>>
    (
        b_q_weight,
        b_gptq_scales,
        b_gptq_qzeros,
        b_g_idx,
        height,
        width,
        groups,
        out
    );
}


void gemm_half_q_half_cuda
(
    cublasHandle_t cublas_handle,
    const half* a,
    const uint32_t* b_q_weight,
    const uint32_t* b_gptq_qzeros,
    const half* b_gptq_scales,
    const int* b_g_idx,
    half* c,
    half* temp_dq,
    int size_m,
    int size_n,
    int size_k,
    int groups,
    bool use_exllama
)
{
    if ((use_exllama && size_m > MAX_Q_GEMM_ROWS) || (!use_exllama && size_m > MAX_ALT_GEMM_ROWS)) {
        // Reconstruct FP16 matrix, then cuBLAS
        if (use_exllama) {
            reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq,
                                size_k, size_n, groups);
        }
        else
        {
            reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
                             temp_dq, size_k, size_n, groups);
        }

        const half alpha = __float2half(1.0f);
        const half beta = __float2half(0.0f);
        cublasHgemm(cublas_handle,
                    CUBLAS_OP_N,
                    CUBLAS_OP_N,
                    size_n, size_m, size_k,
                    &alpha, temp_dq, size_n,
                    a, size_k,
                    &beta, c, size_n);
    }
    else if (use_exllama)
    {
        // Quantized matmul
        int max_chunks = size_m / BLOCK_M_SIZE_MAX;
        int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
        int last_chunk_size = size_m - last_chunk;

        if (max_chunks)
        {
            gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
                                       c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX,
                                       groups);
        }

        if (last_chunk_size)
        {
            gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros,
                                       b_gptq_scales, b_g_idx, c + last_chunk * size_n,
                                       last_chunk_size, size_n, size_k, last_chunk_size,
                                       groups);
        }
    }
    else
    {
        gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
                             c, size_m, size_n, size_k);
    }
}


__global__ void shuffle_kernel
(
    uint32_t* __restrict__ b_q_weight,
    const int size_k,
    const int size_n
)
{
    int n = blockIdx.x * THREADS_X + threadIdx.x;
    if (n >= size_n) return;
    int k = 0;
    uint32_t* b_ptr = b_q_weight + n;
    while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
}


__global__ void make_sequential_kernel
(
    const uint32_t* __restrict__ w,
    uint32_t* __restrict__ w_new,
    const int* __restrict__ q_perm,
    const int w_height,
    const int w_width
)
{
    const uint64_t* w2 = (uint64_t*) w;
    uint64_t* w_new2 = (uint64_t*) w_new;
    int w2_stride = w_width >> 1;
    int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
    if (w2_column >= w2_stride) return;
    int w_new2_row = blockIdx.y;
    int q_perm_idx = w_new2_row << 3;
    uint64_t dst = 0;

    #pragma unroll
    for (int i = 0; i < 8; i++)
    {
        int source_row = q_perm[q_perm_idx++];

        int w2_row = source_row >> 3;
        int w2_subrow = source_row & 0x07;
        int w2_row_shift = w2_subrow << 2;
        int wnew2_row_shift = i << 2;

        uint64_t src = w2[w2_row * w2_stride + w2_column];
        src >>= w2_row_shift;
        src &= 0x0000000f0000000f;
        src <<= wnew2_row_shift;
        dst |= src;
    }
    w_new2[w_new2_row * w2_stride + w2_column] = dst;
}


void shuffle_exllama_weight
(
    uint32_t* q_weight,
    int* q_perm,
    int height,
    int width
)
{
    if (q_perm)
    {
        uint32_t* new_qweight = NULL;
        cudaMalloc(&new_qweight, height / 8 * width * sizeof(uint32_t));

        dim3 blockDim, gridDim;
        blockDim.x = THREADS_X;
        blockDim.y = 1;
        gridDim.x = DIVIDE(width, THREADS_X);
        gridDim.y = height / 8;

        make_sequential_kernel<<<gridDim, blockDim>>>
        (
            q_weight,
            new_qweight,
            q_perm,
            height / 8,
            width
        );
        // Replace qweights
        cudaMemcpyAsync(q_weight, new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
        // Cleanup
        cudaDeviceSynchronize();
        cudaFree(new_qweight);
    }
    dim3 blockDim, gridDim;
    blockDim.x = THREADS_X;
    blockDim.y = 1;
    gridDim.x = DIVIDE(width, THREADS_X);
    gridDim.y = 1;
    shuffle_kernel<<<gridDim, blockDim>>>(q_weight, height, width);
}

}  // namespace gptq
}  // namespace vllm

torch::Tensor gptq_gemm
(
    torch::Tensor a,
    torch::Tensor b_q_weight,
    torch::Tensor b_gptq_qzeros,
    torch::Tensor b_gptq_scales,
    torch::Tensor b_g_idx,
    bool use_exllama
)
{
    const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
    auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
    at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
    at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 8, b_q_weight.size(1)}, options);

    vllm::gptq::gemm_half_q_half_cuda
    (
        at::cuda::getCurrentCUDABlasHandle(),
        (const half*) a.data_ptr(),
        (const uint32_t*) b_q_weight.data_ptr(),
        (const uint32_t*) b_gptq_qzeros.data_ptr(),
        (const half*) b_gptq_scales.data_ptr(),
        b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(),
        (half*) c.data_ptr(),
        (half*) temp_dq.data_ptr(),
        c.size(0),  // m
        c.size(1),  // n
        a.size(1),  // k
        b_gptq_qzeros.size(0),  // group number
        use_exllama
    );
    return c;
}

void gptq_shuffle
(
    torch::Tensor q_weight,
    torch::Tensor q_perm
)
{
    const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
    vllm::gptq::shuffle_exllama_weight(
        (uint32_t*) q_weight.data_ptr(),
        q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
        q_weight.size(0) * 8,
        q_weight.size(1)
    );
}
csrc/quantization/gptq/qdq_4.cuh (new file, 235 lines)
@@ -0,0 +1,235 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _qdq_4_cuh
#define _qdq_4_cuh

#include "qdq_util.cuh"

namespace vllm {
namespace gptq {
// Permutation:
//
// 77775555 33331111  66664444 22220000

__forceinline__ __device__ void shuffle_4bit_8
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0];
    uint32_t qb = 0;

    #pragma unroll
    for (int i = 0; i < 4; i++)
    {
        uint32_t qa0 = qa & 0x0f;
        uint32_t qa1 = (qa & 0xf0) >> 4;
        qa >>= 8;
        qb |= (qa1 << (i * 4 + 16));
        qb |= (qa0 << (i * 4));
    }
    q[0] = qb;
}

__forceinline__ __device__ void dequant_4bit_8
(
    const uint32_t q_0,
    half2 (&dq)[4],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y16_ = __float2half_rn(1.0f / 16.0f);
    const half2 y16 = __halves2half2(y16_, y16_);
    const half z1_ = __float2half_rn(-1024.0f - 8.0f);
    const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
    const half2 z1 = __halves2half2(z1_, z1_);
    const half2 z16 = __halves2half2(z16_, z16_);

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
    qa >>= 8;
    half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5])      + 1024
    half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024

    dq[0] = __hadd2(q0.as_half2, z1);
    dq[1] = __hfma2(q1.as_half2, y16, z16);
    dq[2] = __hadd2(q2.as_half2, z1);
    dq[3] = __hfma2(q3.as_half2, y16, z16);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
    const uint32_t zero,
    const half scale,
    half2 (&z1z16)[2],
    half2 (&y1y16)[2]
)
{
    half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
    half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

    half2 scale2 = __half2half2(scale);

    z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
    z1z16[1] = __hmul2(scale2, __half2half2(z16));

    const half y1 = __float2half_rn(1.0f);
    const half y16 = __float2half_rn(1.0f / 16.0f);

    y1y16[0] = __hmul2(scale2, __half2half2(y1));
    y1y16[1] = __hmul2(scale2, __half2half2(y16));
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
    const uint32_t zero,
    half2(&z1z16)[2],
    half2(&y1y16)[2]
)
{
    half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
    half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

    z1z16[0] = __half2half2(z1.as_half);
    z1z16[1] = __half2half2(z16);

    const half y1 = __float2half_rn(1.0f);
    const half y16 = __float2half_rn(1.0f / 16.0f);

    y1y16[0] = __half2half2(y1);
    y1y16[1] = __half2half2(y16);
}


__forceinline__ __device__ void dequant_4bit_8_gptq
(
    const uint32_t q_0,
    half2 (&dq)[4],
    half2 (&z1z16)[2],
    half2 (&y1y16)[2],
    int stride,
    bool scaled
)
{
    const uint32_t c0 = 0x64006400;

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0]      + 1024, q[1]      + 1024 )
    half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
    qa >>= 8;
    half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4]      + 1024, q[5]      + 1024 )
    half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )

    if (scaled)
    {
        dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
        dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
        dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
        dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
    }
    else
    {
        dq[0] = __hadd2(q0.as_half2, z1z16[0]);           // half2( q[0] - z, q[1] - z )
        dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
        dq[2] = __hadd2(q2.as_half2, z1z16[0]);           // half2( q[4] - z, q[5] - z )
        dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
    }
}
}  // namespace gptq
}  // namespace vllm

#else

namespace vllm {
namespace gptq {
__forceinline__ __device__ void shuffle_4bit_8
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_4bit_8
(
    const uint32_t q_0,
    half2 (&dq)[4],
    int stride
)
{
    half dqh[8];
    for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);

    for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
    const uint32_t zero,
    const half scale,
    half2 (&z1)[2],
    half2 (&y1)[2]
)
{
    half z = __int2half_rn(-((int)zero));
    z = __hmul(z, scale);
    z1[0] = __half2half2(z);
    y1[0] = __half2half2(scale);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
    const uint32_t zero,
    half2(&z1)[2],
    half2(&y1)[2]
)
{
    half z = __int2half_rn(-((int)zero));
    z1[0] = __half2half2(z);
}

__forceinline__ __device__ void dequant_4bit_8_gptq
(
    const uint32_t q_0,
    half2 (&dq)[4],
    half2 (&z1)[2],
    half2 (&y1)[2],
    int stride,
    bool scaled
)
{
    half2 dqh2[8];

    uint32_t qa = q_0;
    for (int i = 0; i < 4; i++)
    {
        half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
        half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
        dqh2[i] = __halves2half2(d0, d1);
    }

    if (scaled)
    {
        dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
        dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
        dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
        dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
    }
    else
    {
        dq[0] = __hadd2(dqh2[0], z1[0]);
        dq[1] = __hadd2(dqh2[1], z1[0]);
        dq[2] = __hadd2(dqh2[2], z1[0]);
        dq[3] = __hadd2(dqh2[3], z1[0]);
    }
}

}  // namespace gptq
}  // namespace vllm

#endif
csrc/quantization/gptq/qdq_util.cuh (new file, 60 lines)
@@ -0,0 +1,60 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _qdq_util_cuh
#define _qdq_util_cuh

namespace vllm {
namespace gptq {

union half2_uint32
{
    uint32_t as_uint32;
    half2 as_half2;
    __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
    __device__ half2_uint32(half2 val) : as_half2(val) {}
};

union half_uint16
{
    uint16_t as_uint16;
    half as_half;
    __device__ half_uint16(uint16_t val) : as_uint16(val) {}
    __device__ half_uint16(half val) : as_half(val) {}
};

// Max_scale premultiplied by 1/256

__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
{
    int qs_i = qs + 1;
    half qs_h = __int2half_rn(qs_i * qs_i);
    qs_h = __hmul(qs_h, max_scale);
    return qs_h;
}

__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
{
    return __hmul(__int2half_rn(q - qzero), scale);
}

__forceinline__ __device__ half dq_ns(const int q, const int qzero)
{
    //return __hsub(__int2half_rn(q), __int2half_rn(qzero));
    return __int2half_rn(q - qzero);
}

__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
{
    return (int)((q >> shift) & mask);
}

__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
{
    return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}

}  // namespace gptq
}  // namespace vllm
#endif
setup.py (1 line changed)
@@ -219,6 +219,7 @@ vllm_extension_sources = [
     "csrc/activation_kernels.cu",
     "csrc/layernorm_kernels.cu",
     "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
+    "csrc/quantization/gptq/q_gemm.cu",
     "csrc/cuda_utils_kernels.cu",
     "csrc/pybind.cpp",
 ]
@@ -142,7 +142,7 @@ class ModelConfig:
         self.tokenizer_mode = tokenizer_mode

     def _verify_quantization(self) -> None:
-        supported_quantization = ["awq", "squeezellm"]
+        supported_quantization = ["awq", "gptq", "squeezellm"]
         rocm_not_supported_quantization = ["awq"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
@@ -179,7 +179,7 @@ class EngineArgs:
         parser.add_argument('--quantization',
                             '-q',
                             type=str,
-                            choices=['awq', 'squeezellm', None],
+                            choices=['awq', 'gptq', 'squeezellm', None],
                             default=None,
                             help='Method used to quantize the weights')
         return parser
@@ -38,8 +38,9 @@ class LLM:
             However, if the `torch_dtype` in the config is `float32`, we will
             use `float16` instead.
         quantization: The method used to quantize the model weights. Currently,
-            we support "awq". If None, we assume the model weights are not
-            quantized and use `dtype` to determine the data type of the weights.
+            we support "awq", "gptq" and "squeezellm". If None, we assume the
+            model weights are not quantized and use `dtype` to determine the
+            data type of the weights.
         revision: The specific model version to use. It can be a branch name,
             a tag name, or a commit id.
         tokenizer_revision: The specific tokenizer version to use. It can be a
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional

 import torch
 import torch.nn.functional as F
@@ -21,8 +21,10 @@ class LinearMethodBase(ABC):
     """Base class for different (maybe quantized) linear methods."""

     @abstractmethod
-    def create_weights(self, input_size: int, output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
         """Create weights for a linear layer."""
         raise NotImplementedError

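The widened create_weights() signature above now receives both the per-partition and the full layer sizes, and the returned dictionary may hold non-tensor entries. A minimal sketch of a custom linear method written against the new signature (the class name, the import path, and the extra dictionary entry are hypothetical, not part of this commit):

    # Hypothetical subclass illustrating the new create_weights() signature.
    from typing import Any, Dict

    import torch
    from torch.nn.parameter import Parameter

    from vllm.model_executor.layers.linear import LinearMethodBase  # assumed import path

    class MyLinearMethod(LinearMethodBase):
        def create_weights(self, input_size_per_partition: int,
                           output_size_per_partition: int, input_size: int,
                           output_size: int,
                           params_dtype: torch.dtype) -> Dict[str, Any]:
            weight = Parameter(torch.empty(output_size_per_partition,
                                           input_size_per_partition,
                                           dtype=params_dtype),
                               requires_grad=False)
            # Non-tensor entries are now allowed; layers only register torch.Tensor values.
            return {"weight": weight, "full_shape": (output_size, input_size)}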
@@ -46,10 +48,12 @@ class UnquantizedLinearMethod(LinearMethodBase):
     def __init__(self, separate_bias_add: bool = False):
         self.separate_bias_add = separate_bias_add

-    def create_weights(self, input_size: int, output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
-        weight = Parameter(torch.empty(output_size,
-                                       input_size,
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
+        weight = Parameter(torch.empty(output_size_per_partition,
+                                       input_size_per_partition,
                                        device=torch.cuda.current_device(),
                                        dtype=params_dtype),
                            requires_grad=False)
@@ -102,8 +106,10 @@ class ReplicatedLinear(torch.nn.Module):
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
         self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size, self.params_dtype)
+            self.input_size, self.output_size, self.input_size,
+            self.output_size, self.params_dtype)
         for name, weight in self.linear_weights.items():
+            if isinstance(weight, torch.Tensor):
                 self.register_parameter(name, weight)
         if bias:
             self.bias = Parameter(
@ -168,8 +174,10 @@ class ColumnParallelLinear(torch.nn.Module):
|
|||||||
linear_method = UnquantizedLinearMethod()
|
linear_method = UnquantizedLinearMethod()
|
||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.linear_weights = self.linear_method.create_weights(
|
self.linear_weights = self.linear_method.create_weights(
|
||||||
self.input_size, self.output_size_per_partition, self.params_dtype)
|
self.input_size, self.output_size_per_partition, self.input_size,
|
||||||
|
self.output_size, self.params_dtype)
|
||||||
for name, weight in self.linear_weights.items():
|
for name, weight in self.linear_weights.items():
|
||||||
|
if isinstance(weight, torch.Tensor):
|
||||||
self.register_parameter(name, weight)
|
self.register_parameter(name, weight)
|
||||||
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
|
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
|
||||||
if bias:
|
if bias:
|
||||||
@ -295,6 +303,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||||
shard_size)
|
shard_size)
|
||||||
else:
|
else:
|
||||||
|
ignore_warning = getattr(param, "ignore_warning", False)
|
||||||
|
if not ignore_warning:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Loading a weight without `output_dim` attribute in "
|
"Loading a weight without `output_dim` attribute in "
|
||||||
"MergedColumnParallelLinear, assume the weight is "
|
"MergedColumnParallelLinear, assume the weight is "
|
||||||
@ -418,6 +428,8 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||||
shard_size)
|
shard_size)
|
||||||
else:
|
else:
|
||||||
|
ignore_warning = getattr(param, "ignore_warning", False)
|
||||||
|
if not ignore_warning:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Loading a weight without `output_dim` attribute in "
|
"Loading a weight without `output_dim` attribute in "
|
||||||
"QKVParallelLinear, assume the weight is the same "
|
"QKVParallelLinear, assume the weight is the same "
|
||||||
@ -481,8 +493,10 @@ class RowParallelLinear(torch.nn.Module):
|
|||||||
linear_method = UnquantizedLinearMethod()
|
linear_method = UnquantizedLinearMethod()
|
||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.linear_weights = self.linear_method.create_weights(
|
self.linear_weights = self.linear_method.create_weights(
|
||||||
self.input_size_per_partition, self.output_size, self.params_dtype)
|
self.input_size_per_partition, self.output_size, self.input_size,
|
||||||
|
self.output_size, self.params_dtype)
|
||||||
for name, weight in self.linear_weights.items():
|
for name, weight in self.linear_weights.items():
|
||||||
|
if isinstance(weight, torch.Tensor):
|
||||||
self.register_parameter(name, weight)
|
self.register_parameter(name, weight)
|
||||||
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
|
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
|
||||||
|
|
||||||
|
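In the hunks above, create_weights() now receives both the per-partition sizes and the full layer sizes, and its return type widens to Dict[str, Any] so non-tensor bookkeeping (such as GPTQ's exllama state below) can travel alongside the parameters; the new isinstance check is what keeps such entries out of register_parameter. A minimal sketch of a linear method written against the new signature; DummyLinearMethod is illustrative only and not part of this commit.

from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import LinearMethodBase


class DummyLinearMethod(LinearMethodBase):
    """Illustrative only: exercises the new create_weights() signature."""

    def create_weights(self, input_size_per_partition: int,
                       output_size_per_partition: int, input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        # A RowParallelLinear shard sees input_size_per_partition < input_size;
        # quantized methods use that to detect sharding along the input dim.
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        # Non-tensor entries are allowed now; they are skipped by the
        # isinstance(weight, torch.Tensor) guard in the layer constructors.
        return {
            "weight": weight,
            "is_sharded": input_size != input_size_per_partition,
        }

    def apply_weights(self,
                      weights: Dict[str, Any],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        return F.linear(x, weights["weight"], bias)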
@ -1,11 +1,13 @@
 from typing import Type

-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig

 _QUANTIZATION_CONFIG_REGISTRY = {
     "awq": AWQConfig,
+    "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
 }

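With the registry entry added, resolving the method name to its config class is a plain dictionary lookup. A small sketch, assuming the private registry name above stays importable from the package and that from_config() consumes the keys shown in the GPTQ hunk further down:

from vllm.model_executor.layers.quantization import _QUANTIZATION_CONFIG_REGISTRY

config_cls = _QUANTIZATION_CONFIG_REGISTRY["gptq"]
# Typical contents of a quantize_config.json produced by a GPTQ exporter.
quant_config = config_cls.from_config({
    "bits": 4,
    "group_size": 128,
    "desc_act": False,
})
print(quant_config)
# GPTQConfig(weight_bits=4, group_size=128, desc_act=False)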
@ -77,14 +77,16 @@ class AWQLinearMethod(LinearMethodBase):
     def __init__(self, quant_config: AWQConfig):
         self.quant_config = quant_config

-    def create_weights(self, input_size: int, output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
-        if input_size % self.quant_config.group_size != 0:
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
+        if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
                 "tensor parallel size.")
-        if output_size % self.quant_config.pack_factor != 0:
+        if output_size_per_partition % self.quant_config.pack_factor != 0:
             raise ValueError(
                 "The output size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
@ -92,8 +94,8 @@ class AWQLinearMethod(LinearMethodBase):

         qweight = Parameter(
             torch.empty(
-                input_size,
-                output_size // self.quant_config.pack_factor,
+                input_size_per_partition,
+                output_size_per_partition // self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
@ -108,8 +110,8 @@ class AWQLinearMethod(LinearMethodBase):
             })
         qzeros = Parameter(
             torch.empty(
-                input_size // self.quant_config.group_size,
-                output_size // self.quant_config.pack_factor,
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition // self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
@ -124,8 +126,8 @@ class AWQLinearMethod(LinearMethodBase):
             })
         scales = Parameter(
             torch.empty(
-                input_size // self.quant_config.group_size,
-                output_size,
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition,
                 device="cuda",
                 dtype=params_dtype,
             ),
@ -142,7 +144,7 @@ class AWQLinearMethod(LinearMethodBase):
         }

     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      weights: Dict[str, Any],
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         qweight = weights["qweight"]
 215  vllm/model_executor/layers/quantization/gptq.py  Normal file
@ -0,0 +1,215 @@
import enum
from enum import Enum
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class GPTQConfig(QuantizationConfig):
    """Config class for GPTQ.

    Reference: https://arxiv.org/abs/2210.17323
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        desc_act: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.pack_factor = 32 // self.weight_bits
        # exllama kernel v1 only supports 4 bit
        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"GPTQ, but got {self.weight_bits} bits.")

    def __repr__(self) -> str:
        return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    # Need to figure it out
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        return cls(weight_bits, group_size, desc_act)

    def get_linear_method(self) -> "GPTQLinearMethod":
        return GPTQLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []


class ExllamaState(Enum):

    UNUSED = enum.auto()
    UNINITIALIZED = enum.auto()
    READY = enum.auto()


class GPTQLinearMethod(LinearMethodBase):
    """Linear method for GPTQ.

    Args:
        quant_config: The GPTQ quantization config.
    """

    def __init__(self, quant_config: GPTQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        input_size_per_partition: int,
        output_size_per_partition: int,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        del output_size  # Unused.
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size
        exllama_state = ExllamaState.UNINITIALIZED
        scale_and_zero_size = input_size // group_size
        scale_and_zero_input_dim = None
        if input_size != input_size_per_partition and self.quant_config.group_size != -1:
            # For act-order models, we cannot use Exllama for row parallel layer
            if self.quant_config.desc_act:
                exllama_state = ExllamaState.UNUSED
            else:
                # we need to partition qzeros and scales for exllama kernel
                scale_and_zero_size = input_size_per_partition // group_size
                scale_and_zero_input_dim = 0

        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            })
        g_idx = Parameter(
            torch.tensor(
                [
                    i // self.quant_config.group_size
                    for i in range(input_size_per_partition)
                ],
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True})
        qzeros = Parameter(
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition // self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                "input_dim": scale_and_zero_input_dim,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        scales = Parameter(
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": scale_and_zero_input_dim,
            "output_dim": 1,
        })
        return {
            "qweight": qweight,
            "g_idx": g_idx,
            "qzeros": qzeros,
            "scales": scales,
            "exllama_state": exllama_state,
        }

    def apply_weights(self,
                      weights: Dict[str, Any],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = weights["qweight"]
        out_shape = x.shape[:-1] + (qweight.shape[-1], )
        reshaped_x = x.reshape(-1, x.shape[-1])
        # exllama needs to shuffle the weight after the weight is loaded
        # here we do the shuffle on first forward pass
        if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
            if self.quant_config.desc_act:
                weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
                    torch.int)
            else:
                weights["g_idx"] = torch.empty((1, 1), device="meta")
            weights["exllama_state"] = ExllamaState.READY
            ops.gptq_shuffle(weights["qweight"], weights["g_idx"])
        output = ops.gptq_gemm(reshaped_x, weights["qweight"],
                               weights["qzeros"], weights["scales"],
                               weights["g_idx"],
                               weights["exllama_state"] == ExllamaState.READY)
        if bias is not None:
            output = output + bias
        return output.reshape(out_shape)
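To make the tensor shapes allocated by create_weights() above concrete, here is a small worked sketch with illustrative numbers only: 4-bit weights, group_size=128, and an unsharded 4096x11008 layer, so input_size equals input_size_per_partition.

weight_bits = 4
group_size = 128
pack_factor = 32 // weight_bits       # 8 packed 4-bit values per int32
input_size = 4096                     # input features of the layer
output_size = 11008                   # output features of the layer

# Shapes created above for this layer:
qweight = (input_size // pack_factor, output_size)                # (512, 11008), int32
qzeros = (input_size // group_size, output_size // pack_factor)   # (32, 1376), int32
scales = (input_size // group_size, output_size)                  # (32, 11008), params_dtype
g_idx_len = input_size                                            # one group index per input channel

print(qweight, qzeros, scales, g_idx_len)

On the first forward pass, apply_weights() argsorts g_idx (for desc_act checkpoints) or replaces it with an empty meta tensor, calls gptq_shuffle once, and then marks the layer READY so later calls go straight to gptq_gemm.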
@ -67,17 +67,19 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
     def __init__(self, quant_config: SqueezeLLMConfig):
         self.quant_config = quant_config

-    def create_weights(self, input_size: int, output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
-        if input_size % self.quant_config.pack_factor != 0:
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
+        if input_size_per_partition % self.quant_config.pack_factor != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
                 "tensor parallel size.")
         qweight = Parameter(
             torch.empty(
-                input_size // self.quant_config.pack_factor,
-                output_size,
+                input_size_per_partition // self.quant_config.pack_factor,
+                output_size_per_partition,
                 device="cuda",
                 dtype=torch.int32,
             ),
@ -108,7 +110,7 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
         }

     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      weights: Dict[str, Any],
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         qweight = weights["qweight"]
@ -332,11 +332,18 @@ class AquilaForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
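The same two guards recur in every load_weights() below: GPTQ checkpoints can contain .bias entries for layers that vLLM builds without a bias parameter, and indexing params_dict with such a name would raise a KeyError. A condensed sketch of the pattern; the helper name is hypothetical and not part of the commit.

def _is_extra_gptq_bias(name: str, params_dict: dict) -> bool:
    """True for checkpoint bias tensors with no matching model parameter,
    e.g. biases serialized by some GPTQ exporters."""
    return name.endswith(".bias") and name not in params_dict


# Inside load_weights(), before indexing params_dict:
# if _is_extra_gptq_bias(name, params_dict):
#     continue
# param = params_dict[name]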
@ -355,11 +355,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@ -377,6 +377,9 @@ class ChatGLMForCausalLM(nn.Module):
                 continue
             if "word_embeddings" in name:
                 name = name.replace(".word_embeddings", "")
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
@ -425,16 +425,21 @@ class FalconForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
             param = params_dict[name]
             if "query_key_value" in name:
                 output_dim = getattr(param, "output_dim", None)
                 loaded_weight_shape = loaded_weight.shape
+                if output_dim is not None:
                     loaded_weight = loaded_weight.view(
                         loaded_weight_shape[:output_dim] +
-                        (total_num_kv_heads, num_query_heads_per_kv_head + 2, -1) +
-                        loaded_weight_shape[output_dim + 1:])
+                        (total_num_kv_heads, num_query_heads_per_kv_head + 2,
+                         -1) + loaded_weight_shape[output_dim + 1:])
                     wq = loaded_weight.narrow(
-                        output_dim + 1, 0, num_query_heads_per_kv_head).reshape(
+                        output_dim + 1, 0,
+                        num_query_heads_per_kv_head).reshape(
                             *loaded_weight_shape[:output_dim], -1,
                             *loaded_weight_shape[output_dim + 1:])
                     wk = loaded_weight.narrow(
@ -275,7 +275,6 @@ class GPT2LMHeadModel(nn.Module):
                 if not name.endswith(".weight"):
                     continue
                 loaded_weight = loaded_weight.t()
-
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
@ -274,11 +274,18 @@ class GPTJForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@ -72,7 +72,6 @@ class GPTNeoXAttention(nn.Module):
             config.hidden_size,
             linear_method=linear_method,
         )
-
         scaling = self.head_size**-0.5
         rotary_dim = int(self.head_size * config.rotary_pct)
         assert rotary_dim % 2 == 0
@ -289,11 +289,18 @@ class InternLMForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@ -330,11 +330,18 @@ class LlamaForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@ -321,11 +321,18 @@ class MistralForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@ -153,7 +153,7 @@ class MixtralMoE(nn.Module):
         self.gate = ReplicatedLinear(config.hidden_size,
                                      self.num_total_experts,
                                      bias=False,
-                                     linear_method=linear_method)
+                                     linear_method=None)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
@ -418,11 +418,18 @@ class MixtralForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@ -297,6 +297,9 @@ class MPTForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
@ -345,11 +345,18 @@ class OPTForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@ -305,6 +305,9 @@ class PhiForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue

+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
             # pylint: disable=E1136
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
@ -82,7 +82,6 @@ class QWenAttention(nn.Module):
         self.num_heads = (self.total_num_heads //
                           tensor_model_parallel_world_size)
         self.head_dim = hidden_size // self.total_num_heads
-
         self.c_attn = QKVParallelLinear(
             hidden_size,
             self.head_dim,
@ -279,11 +278,18 @@ class QWenLMHeadModel(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@ -320,11 +320,18 @@ class YiForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@ -287,4 +287,5 @@ def initialize_dummy_weights(
         values between -1e-3 and 1e-3 works well for most models.
     """
     for param in model.state_dict().values():
+        if torch.is_floating_point(param):
             param.data.uniform_(low, high)
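The new torch.is_floating_point guard matters because quantized layers now put int32 tensors (qweight, qzeros) into the state dict, and uniform_ with fractional bounds is only defined for floating-point tensors. A minimal sketch of the behavior being avoided, under that assumption:

import torch

float_param = torch.empty(4, dtype=torch.float32)
float_param.uniform_(-1e-3, 1e-3)        # fine: dummy-initialized as before

int_param = torch.empty(4, dtype=torch.int32)
# int_param.uniform_(-1e-3, 1e-3)        # would raise a RuntimeError
if torch.is_floating_point(int_param):   # the new guard simply skips it
    int_param.uniform_(-1e-3, 1e-3)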