diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 3c2a80da..578f897b 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -84,7 +84,7 @@ if __name__ == '__main__': parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--quantization', '-q', - choices=['awq', 'squeezellm', None], + choices=['awq', 'gptq', 'squeezellm', None], default=None) parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument('--input-len', type=int, default=32) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 4540ed80..5d17cb3f 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -244,7 +244,7 @@ if __name__ == "__main__": parser.add_argument("--tokenizer", type=str, default=None) parser.add_argument('--quantization', '-q', - choices=['awq', 'squeezellm', None], + choices=['awq', 'gptq', 'squeezellm', None], default=None) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--n", diff --git a/csrc/ops.h b/csrc/ops.h index 5ded1c12..9340a60d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -77,3 +77,15 @@ void squeezellm_gemm( torch::Tensor mat, torch::Tensor mul, torch::Tensor lookup_table); + +torch::Tensor gptq_gemm( + torch::Tensor a, + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + bool use_exllama); + +void gptq_shuffle( + torch::Tensor q_weight, + torch::Tensor q_perm); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index b0120e96..95f55768 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -52,8 +52,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Quantization ops ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); #endif - - + ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); + ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); // Cache ops diff --git a/csrc/quantization/gptq/compat.cuh b/csrc/quantization/gptq/compat.cuh new file mode 100644 index 00000000..4da0bc6e --- /dev/null +++ b/csrc/quantization/gptq/compat.cuh @@ -0,0 +1,64 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _compat_cuh +#define _compat_cuh + +namespace vllm { +namespace gptq { +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) +{ + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do + { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? 
(old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } + while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) +{ + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do + { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } + while (assumed != old); +} + +// + +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) +#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + +__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } + +#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } +#endif + +#endif +#endif + +} // namespace gptq +} // namespace vllm +#endif diff --git a/csrc/quantization/gptq/matrix_view.cuh b/csrc/quantization/gptq/matrix_view.cuh new file mode 100644 index 00000000..1fdf019b --- /dev/null +++ b/csrc/quantization/gptq/matrix_view.cuh @@ -0,0 +1,151 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama +*/ + +#ifndef _matrix_view_cuh +#define _matrix_view_cuh + +#include +#include + +#include "qdq_util.cuh" + +namespace vllm { +namespace gptq { + +class MatrixView_half +{ +public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } + + __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const + { + half2* ptr = (half2*) item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __low2half(i01); + items[1] = __high2half(i01); + items[2] = __low2half(i23); + items[3] = __high2half(i23); + } + __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const + { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2float(__low2half(i01)); + items[1] = __half2float(__high2half(i01)); + items[2] = __half2float(__low2half(i23)); + items[3] = __half2float(__high2half(i23)); + } + + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const + { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2half2(__low2half(i01)); + items[1] = __half2half2(__high2half(i01)); + items[2] = __half2half2(__low2half(i23)); + items[3] = __half2half2(__high2half(i23)); + } +}; + +class MatrixView_half_rw +{ +public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) 
const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } + __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } + + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) + { + half2 v01 = __halves2half2(v0, v1); + half2 v23 = __halves2half2(v2, v3); + half2* ptr = (half2*) item_ptr(row, column); + ptr[0] = v01; + ptr[1] = v23; + } +}; + +class MatrixView_q4_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const + { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const + { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + items[2] = (d >> 8) & 0x0f; + items[3] = (d >> 12) & 0x0f; + } +}; + +class MatrixView_q4_column +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } +}; + +} // namespace gptq +} // namespace vllm +#endif diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu new file mode 100644 index 00000000..6d070c65 --- /dev/null +++ b/csrc/quantization/gptq/q_gemm.cu @@ -0,0 +1,859 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopqwop200/GPTQ-for-LLaMa +*/ + +#include +#include + +#include +#include +#include +#include +#include + +#include "compat.cuh" +#include "matrix_view.cuh" +#include "qdq_4.cuh" + +namespace vllm { +namespace gptq { + +#define BLOCK_KN_SIZE 128 +#define BLOCK_M_SIZE_MAX 8 +#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32) +#define MAX_Q_GEMM_ROWS 50 +#define MAX_ALT_GEMM_ROWS 8 +#define THREADS_X 32 +#define THREADS_Y 32 +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +#if defined(USE_ROCM) +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, + hipblasOperation_t transA, + hipblasOperation_t transB, + int m, + int n, + int k, + const half* alpha, + const half* AP, + int lda, + const half* BP, + int ldb, + const half* 
beta, + half* CP, + int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); +} +#define hipblasHgemm __compat_hipblasHgemm + +// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. +#define rocblas_operation_none HIPBLAS_OP_N +#define rocblas_hgemm __compat_hipblasHgemm +#endif + +__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hadd2(result, g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __half2float(__low2half(result)) + __half2float(__high2half(result)); +} + +typedef void (*fp_gemm_half_q_half_gptq_kernel) +( + const half*, + const uint32_t*, + const uint32_t*, + const half*, + half*, + const int, + const int, + const int, + const int, + const int* +); + +template +__global__ void gemm_half_q_half_gptq_kernel +( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const int* __restrict__ b_q_perm +) +{ + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) + { + for (int m = 0; m < m_count; ++m) + { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; + else a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) return; + + if (blockIdx.z == 0) + { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + float scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + // Column result + float 
block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + const int4* b_ptr4 = (int4*) b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][4]; + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + + #pragma unroll + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); + block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); + block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); + block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); + } + + b_ptr += size_n; + a_ptr += 8; + } + + k += 32; + } + + for (int m = 0; m < m_count; m++) + { + half2 *out = (half2*) c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); + half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); + atomicAdd(out , result01); + atomicAdd(out + 1, result23); + } +} + + +fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count) +{ + #if BLOCK_M_SIZE_MAX >= 1 + if (m_count == 1) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 2 + if (m_count == 2) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 3 + if (m_count == 3) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 4 + if (m_count == 4) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 5 + if (m_count == 5) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 6 + if (m_count == 6) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 7 + if (m_count == 7) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 8 + if (m_count == 8) return gemm_half_q_half_gptq_kernel; + #endif + return NULL; +} + + +void gemm_half_q_half_cuda_part +( + const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_q_perm, + half* c, + int size_m, + int size_n, + int size_k, + int m_count, + int groups +) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count); + + kernel<<>> + ( + a, + b_q_weight, + b_gptq_qzeros, + b_gptq_scales, + c, + size_m, + size_n, + size_k, + groups, + b_q_perm + ); +} + + +__global__ void reconstruct_exllama_kernel +( + const uint32_t* __restrict__ b_q_weight, + const int* __restrict__ b_q_perm, + const uint32_t* __restrict__ 
b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + const int size_k, + const int size_n, + const int groups, + half* __restrict__ b +) +{ + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; + + if (b_q_perm) + { + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + + for (int p = 0; p < 4; p++) + { + half2 dq[4][4]; + const int4* b_ptr4 = (int4*) b_ptr; + int4 load_int4 = *b_ptr4; + + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + + b_ptr += size_n; + //half* dqh = (half*)dq; + if (b_q_perm) + { + for (int j = 0; j < 4; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + else + { + for (int j = 0; j < 4; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + + +void reconstruct_exllama +( + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_q_perm, + half* out, + int height, + int width, + int groups +) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + + reconstruct_exllama_kernel<<>> + ( + b_q_weight, + 
b_q_perm, + b_gptq_qzeros, + b_gptq_scales, + height, + width, + groups, + out + ); +} + + +__global__ void gemm_half_q_half_alt_kernel( + const half2* __restrict__ vec, + const uint32_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int height, + int width +) +{ + int zero_width = width / 8; + int vec_height = height * 4; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + int b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + int h = BLOCK_KN_SIZE * blockIdx.z / 8; + int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; + int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = + vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + + threadIdx.x]; + } + } + + __shared__ half2 deq2[256][8]; + int val = threadIdx.x / 8; + int off = threadIdx.x % 8; + for (; val < 256; val += BLOCK_KN_SIZE / 8) { + deq2[val][off] = __halves2half2( + __int2half_rn(val & 0xF), __int2half_rn(val >> 4) + ); + } + + if (blockIdx.z == 0) + { + for (int m = 0; m < b_end; m++) + mul[(b + m) * width + w] = __int2half_rn(0); + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 8; + int k = 0; + int z_w = w / 8; + int z_mod = (w % 8) * 4; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[4]; + half2 zeros_tmp[4]; + for (int tmp_k = 0; tmp_k < 4; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)), + __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1)) + ); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; + } + for (int m = 0; m < b_end; m++) { + res2 = {}; + res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2); + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); + } + i += width; + k += 4; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } +} + + +void gemm_half_q_half_alt +( + const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, + half* c, + int size_m, + int size_n, + int size_k +) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE); + gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + gemm_half_q_half_alt_kernel<<>> + ( + (const half2*) a, + b_q_weight, + c, + b_gptq_scales, + b_gptq_qzeros, + b_g_idx, + size_m, + size_k / 8, + size_n + ); +} + + +__global__ void reconstruct_gptq_kernel +( + const uint32_t* __restrict__ w, + const half* __restrict__ w_scales, + const 
uint32_t* __restrict__ w_zeros, + const int* __restrict__ g_idx, + const int height, + const int width, + const int group, + half* __restrict__ out +) +{ + // Start of block + + int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + int row = blockIdx.y * 8; + if (column >= width) return; + + // Views + + MatrixView_q4_column w_(w, height, width); + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + MatrixView_q4_row w_zeros_(w_zeros, group, width); + + uint32_t w_read = w_.item_uint32_t(row, column); + half* out_ptr = out_.item_ptr(row, column); + + #pragma unroll + for (int s = 0; s < 32; s += 4) + { + int group = g_idx[row + s / 4]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); + *out_ptr = w_item; out_ptr += out_.width; + } +} + + +void reconstruct_gptq +( + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, + half* out, + int height, + int width, + int groups +) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, 8); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + reconstruct_gptq_kernel<<>> + ( + b_q_weight, + b_gptq_scales, + b_gptq_qzeros, + b_g_idx, + height, + width, + groups, + out + ); +} + + +void gemm_half_q_half_cuda +( + cublasHandle_t cublas_handle, + const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, + half* c, + half* temp_dq, + int size_m, + int size_n, + int size_k, + int groups, + bool use_exllama +) +{ + if ((use_exllama && size_m > MAX_Q_GEMM_ROWS) || (!use_exllama && size_m > MAX_ALT_GEMM_ROWS)) { + // Reconstruct FP16 matrix, then cuBLAS + if (use_exllama) { + reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, + size_k, size_n, groups); + } + else + { + reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups); + } + + const half alpha = __float2half(1.0f); + const half beta = __float2half(0.0f); + cublasHgemm(cublas_handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + size_n, size_m, size_k, + &alpha, temp_dq, size_n, + a, size_k, + &beta, c, size_n); + } + else if (use_exllama) + { + // Quantized matmul + int max_chunks = size_m / BLOCK_M_SIZE_MAX; + int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; + int last_chunk_size = size_m - last_chunk; + + if (max_chunks) + { + gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, + groups); + } + + if (last_chunk_size) + { + gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros, + b_gptq_scales, b_g_idx, c + last_chunk * size_n, + last_chunk_size, size_n, size_k, last_chunk_size, + groups); + } + } + else + { + gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + c, size_m, size_n, size_k); + } +} + + +__global__ void shuffle_kernel +( + uint32_t* __restrict__ b_q_weight, + const int size_k, + const int size_n +) +{ + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; } +} + + +__global__ void make_sequential_kernel +( + const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ 
q_perm, + const int w_height, + const int w_width +) +{ + const uint64_t* w2 = (uint64_t*) w; + uint64_t* w_new2 = (uint64_t*) w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 3; + uint64_t dst = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + + +void shuffle_exllama_weight +( + uint32_t* q_weight, + int* q_perm, + int height, + int width +) +{ + if (q_perm) + { + uint32_t* new_qweight = NULL; + cudaMalloc(&new_qweight, height / 8 * width * sizeof(uint32_t)); + + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = height / 8; + + make_sequential_kernel<<>> + ( + q_weight, + new_qweight, + q_perm, + height / 8, + width + ); + // Replace qweights + cudaMemcpyAsync(q_weight, new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); + // Cleanup + cudaDeviceSynchronize(); + cudaFree(new_qweight); + } + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = 1; + shuffle_kernel<<>>(q_weight, height, width); +} + +} // namespace gptq +} // namespace vllm + +torch::Tensor gptq_gemm +( + torch::Tensor a, + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + bool use_exllama +) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); + at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 8, b_q_weight.size(1)}, options); + + vllm::gptq::gemm_half_q_half_cuda + ( + at::cuda::getCurrentCUDABlasHandle(), + (const half*) a.data_ptr(), + (const uint32_t*) b_q_weight.data_ptr(), + (const uint32_t*)b_gptq_qzeros.data_ptr(), + (const half*) b_gptq_scales.data_ptr(), + b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(), + (half*) c.data_ptr(), + (half*) temp_dq.data_ptr(), + c.size(0), // m + c.size(1), // n + a.size(1), // k + b_gptq_qzeros.size(0), // group number + use_exllama + ); + return c; +} + +void gptq_shuffle +( + torch::Tensor q_weight, + torch::Tensor q_perm +) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); + vllm::gptq::shuffle_exllama_weight( + (uint32_t*) q_weight.data_ptr(), + q_perm.device().is_meta() ? 
NULL : (int*) q_perm.data_ptr(), + q_weight.size(0) * 8, + q_weight.size(1) + ); +} diff --git a/csrc/quantization/gptq/qdq_4.cuh b/csrc/quantization/gptq/qdq_4.cuh new file mode 100644 index 00000000..cfc4635a --- /dev/null +++ b/csrc/quantization/gptq/qdq_4.cuh @@ -0,0 +1,235 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _qdq_4_cuh +#define _qdq_4_cuh + +#include "qdq_util.cuh" + +namespace vllm { +namespace gptq { +// Permutation: +// +// 77775555 33331111 66664444 22220000 + +__forceinline__ __device__ void shuffle_4bit_8 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0]; + uint32_t qb = 0; + + #pragma unroll + for (int i = 0; i < 4; i++) + { + uint32_t qa0 = qa & 0x0f; + uint32_t qa1 = (qa & 0xf0) >> 4; + qa >>= 8; + qb |= (qa1 << (i * 4 + 16)); + qb |= (qa0 << (i * 4)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_4bit_8 +( + const uint32_t q_0, + half2 (&dq)[4], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half2 y16 = __halves2half2(y16_, y16_); + const half z1_ = __float2half_rn(-1024.0f - 8.0f); + const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z16 = __halves2half2(z16_, z16_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y16, z16); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y16, z16); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale +( + const uint32_t zero, + const half scale, + half2 (&z1z16)[2], + half2 (&y1y16)[2] +) +{ + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + half2 scale2 = __half2half2(scale); + + z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); + z1z16[1] = __hmul2(scale2, __half2half2(z16)); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __hmul2(scale2, __half2half2(y1)); + y1y16[1] = __hmul2(scale2, __half2half2(y16)); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero +( + const uint32_t zero, + half2(&z1z16)[2], + half2(&y1y16)[2] +) +{ + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + z1z16[0] = __half2half2(z1.as_half); + z1z16[1] = __half2half2(z16); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __half2half2(y1); + y1y16[1] = __half2half2(y16); +} + + +__forceinline__ __device__ void dequant_4bit_8_gptq +( + const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1z16)[2], + half2 (&y1y16)[2], + int stride, + bool scaled +) +{ + const uint32_t c0 = 0x64006400; + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 ) + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 ) + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) + + if (scaled) + { + dq[0] = __hfma2(q0.as_half2, y1y16[0], 
z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) + dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) + dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); + } + else + { + dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) + dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z ) + dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z ) + } +} +} // namespace gptq +} // namespace vllm + +#else + +namespace vllm { +namespace gptq { +__forceinline__ __device__ void shuffle_4bit_8 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_4bit_8 +( + const uint32_t q_0, + half2 (&dq)[4], + int stride +) +{ + half dqh[8]; + for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8); + + for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale +( + const uint32_t zero, + const half scale, + half2 (&z1)[2], + half2 (&y1)[2] +) +{ + half z = __int2half_rn(-((int)zero)); + z = __hmul(z, scale); + z1[0] = __half2half2(z); + y1[0] = __half2half2(scale); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero +( + const uint32_t zero, + half2(&z1)[2], + half2(&y1)[2] +) +{ + half z = __int2half_rn(-((int)zero)); + z1[0] = __half2half2(z); +} + +__forceinline__ __device__ void dequant_4bit_8_gptq +( + const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1)[2], + half2 (&y1)[2], + int stride, + bool scaled +) +{ + half2 dqh2[8]; + + uint32_t qa = q_0; + for (int i = 0; i < 4; i++) + { + half d0 = __int2half_rn(qa & 0x0f); qa >>= 4; + half d1 = __int2half_rn(qa & 0x0f); qa >>= 4; + dqh2[i] = __halves2half2(d0, d1); + } + + if (scaled) + { + dq[0] = __hfma2(dqh2[0], y1[0], z1[0]); + dq[1] = __hfma2(dqh2[1], y1[0], z1[0]); + dq[2] = __hfma2(dqh2[2], y1[0], z1[0]); + dq[3] = __hfma2(dqh2[3], y1[0], z1[0]); + } + else + { + dq[0] = __hadd2(dqh2[0], z1[0]); + dq[1] = __hadd2(dqh2[1], z1[0]); + dq[2] = __hadd2(dqh2[2], z1[0]); + dq[3] = __hadd2(dqh2[3], z1[0]); + } +} + +} // namespace gptq +} // namespace vllm + +#endif diff --git a/csrc/quantization/gptq/qdq_util.cuh b/csrc/quantization/gptq/qdq_util.cuh new file mode 100644 index 00000000..1722a9aa --- /dev/null +++ b/csrc/quantization/gptq/qdq_util.cuh @@ -0,0 +1,60 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _qdq_util_cuh +#define _qdq_util_cuh + +namespace vllm { +namespace gptq { + +union half2_uint32 +{ + uint32_t as_uint32; + half2 as_half2; + __device__ half2_uint32(uint32_t val) : as_uint32(val) {} + __device__ half2_uint32(half2 val) : as_half2(val) {} +}; + +union half_uint16 +{ + uint16_t as_uint16; + half as_half; + __device__ half_uint16(uint16_t val) : as_uint16(val) {} + __device__ half_uint16(half val) : as_half(val) {} +}; + +// Max_scale premultiplied by 1/256 + +__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) +{ + int qs_i = qs + 1; + half qs_h = __int2half_rn(qs_i * qs_i); + qs_h = __hmul(qs_h, max_scale); + return qs_h; +} + +__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) +{ + return __hmul(__int2half_rn(q - qzero), scale); +} + +__forceinline__ __device__ half dq_ns(const int q, const int qzero) +{ + //return __hsub(__int2half_rn(q), __int2half_rn(qzero)); + return 
__int2half_rn(q - qzero); +} + +__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) +{ + return (int)((q >> shift) & mask); +} + +__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) +{ + return (int)(__funnelshift_rc(q0, q1, shift) & mask); +} + +} // namespace gptq +} // namespace vllm +#endif diff --git a/setup.py b/setup.py index 2edf0b6a..811d494e 100644 --- a/setup.py +++ b/setup.py @@ -219,6 +219,7 @@ vllm_extension_sources = [ "csrc/activation_kernels.cu", "csrc/layernorm_kernels.cu", "csrc/quantization/squeezellm/quant_cuda_kernel.cu", + "csrc/quantization/gptq/q_gemm.cu", "csrc/cuda_utils_kernels.cu", "csrc/pybind.cpp", ] diff --git a/vllm/config.py b/vllm/config.py index eb1fee0f..15020e79 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -142,7 +142,7 @@ class ModelConfig: self.tokenizer_mode = tokenizer_mode def _verify_quantization(self) -> None: - supported_quantization = ["awq", "squeezellm"] + supported_quantization = ["awq", "gptq", "squeezellm"] rocm_not_supported_quantization = ["awq"] if self.quantization is not None: self.quantization = self.quantization.lower() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8dec696e..2cc69092 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -179,7 +179,7 @@ class EngineArgs: parser.add_argument('--quantization', '-q', type=str, - choices=['awq', 'squeezellm', None], + choices=['awq', 'gptq', 'squeezellm', None], default=None, help='Method used to quantize the weights') return parser diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b05ba71c..75edd49b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -38,8 +38,9 @@ class LLM: However, if the `torch_dtype` in the config is `float32`, we will use `float16` instead. quantization: The method used to quantize the model weights. Currently, - we support "awq". If None, we assume the model weights are not - quantized and use `dtype` to determine the data type of the weights. + we support "awq", "gptq" and "squeezellm". If None, we assume the + model weights are not quantized and use `dtype` to determine the + data type of the weights. revision: The specific model version to use. It can be a branch name, a tag name, or a commit id. tokenizer_revision: The specific tokenizer version to use. 
It can be a diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 810efb67..5190de65 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch import torch.nn.functional as F @@ -21,8 +21,10 @@ class LinearMethodBase(ABC): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: """Create weights for a linear layer.""" raise NotImplementedError @@ -46,10 +48,12 @@ class UnquantizedLinearMethod(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add - def create_weights(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: - weight = Parameter(torch.empty(output_size, - input_size, + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + weight = Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, device=torch.cuda.current_device(), dtype=params_dtype), requires_grad=False) @@ -102,9 +106,11 @@ class ReplicatedLinear(torch.nn.Module): linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size, self.params_dtype) + self.input_size, self.output_size, self.input_size, + self.output_size, self.params_dtype) for name, weight in self.linear_weights.items(): - self.register_parameter(name, weight) + if isinstance(weight, torch.Tensor): + self.register_parameter(name, weight) if bias: self.bias = Parameter( torch.empty(self.output_size, @@ -168,10 +174,12 @@ class ColumnParallelLinear(torch.nn.Module): linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size_per_partition, self.params_dtype) + self.input_size, self.output_size_per_partition, self.input_size, + self.output_size, self.params_dtype) for name, weight in self.linear_weights.items(): - self.register_parameter(name, weight) - set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + if isinstance(weight, torch.Tensor): + self.register_parameter(name, weight) + set_weight_attrs(weight, {"weight_loader": self.weight_loader}) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -295,10 +303,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear): loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) else: - logger.warning( - "Loading a weight without `output_dim` attribute in " - "MergedColumnParallelLinear, assume the weight is " - "the same for all partitions.") + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "MergedColumnParallelLinear, assume the weight is " + "the same for all partitions.") assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) 
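The linear.py changes above widen LinearMethodBase.create_weights from (input_size, output_size, params_dtype) to (input_size_per_partition, output_size_per_partition, input_size, output_size, params_dtype) and loosen the return type to Dict[str, Any], so a quantized method can compare the sharded and full layer sizes and can hand back non-tensor state; the parallel layers now call register_parameter only for entries that are torch.Tensor. A minimal sketch of a linear method written against the new signature (the class name ToyLinearMethod and its behaviour are illustrative, not part of this change):

# Illustrative sketch only -- not part of this diff. Shows the new
# create_weights signature; ToyLinearMethod is a hypothetical name.
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)


class ToyLinearMethod(LinearMethodBase):

    def create_weights(self, input_size_per_partition: int,
                       output_size_per_partition: int, input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        # Row-parallel layers shard the input dimension, so the two input
        # sizes differ; GPTQLinearMethod uses this same comparison to decide
        # whether qzeros/scales must be partitioned or Exllama disabled for
        # desc_act checkpoints.
        input_is_sharded = input_size != input_size_per_partition
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        # Non-tensor values are allowed now that the return type is
        # Dict[str, Any]; callers skip register_parameter for them.
        return {"weight": weight, "input_is_sharded": input_is_sharded}

    def apply_weights(self,
                      weights: Dict[str, Any],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        return F.linear(x, weights["weight"], bias)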
@@ -418,10 +428,12 @@ class QKVParallelLinear(ColumnParallelLinear): loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) else: - logger.warning( - "Loading a weight without `output_dim` attribute in " - "QKVParallelLinear, assume the weight is the same " - "for all partitions.") + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "QKVParallelLinear, assume the weight is the same " + "for all partitions.") assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -481,10 +493,12 @@ class RowParallelLinear(torch.nn.Module): linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( - self.input_size_per_partition, self.output_size, self.params_dtype) + self.input_size_per_partition, self.output_size, self.input_size, + self.output_size, self.params_dtype) for name, weight in self.linear_weights.items(): - self.register_parameter(name, weight) - set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + if isinstance(weight, torch.Tensor): + self.register_parameter(name, weight) + set_weight_attrs(weight, {"weight_loader": self.weight_loader}) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 3d937ba6..b3449eaf 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,11 +1,13 @@ from typing import Type -from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.awq import AWQConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig _QUANTIZATION_CONFIG_REGISTRY = { "awq": AWQConfig, + "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, } diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 95d419e6..831576b1 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -77,14 +77,16 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: - if input_size % self.quant_config.group_size != 0: + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") - if output_size % self.quant_config.pack_factor != 0: + if output_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The output size is not aligned with the quantized " "weight shape. 
This can be caused by too large " @@ -92,8 +94,8 @@ class AWQLinearMethod(LinearMethodBase): qweight = Parameter( torch.empty( - input_size, - output_size // self.quant_config.pack_factor, + input_size_per_partition, + output_size_per_partition // self.quant_config.pack_factor, device="cuda", dtype=torch.int32, ), @@ -108,8 +110,8 @@ class AWQLinearMethod(LinearMethodBase): }) qzeros = Parameter( torch.empty( - input_size // self.quant_config.group_size, - output_size // self.quant_config.pack_factor, + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition // self.quant_config.pack_factor, device="cuda", dtype=torch.int32, ), @@ -124,8 +126,8 @@ class AWQLinearMethod(LinearMethodBase): }) scales = Parameter( torch.empty( - input_size // self.quant_config.group_size, - output_size, + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition, device="cuda", dtype=params_dtype, ), @@ -142,7 +144,7 @@ class AWQLinearMethod(LinearMethodBase): } def apply_weights(self, - weights: Dict[str, torch.Tensor], + weights: Dict[str, Any], x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: qweight = weights["qweight"] diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py new file mode 100644 index 00000000..8fe96e7d --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -0,0 +1,215 @@ +import enum +from enum import Enum +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm._C import ops +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + + +class GPTQConfig(QuantizationConfig): + """Config class for GPTQ. + + Reference: https://arxiv.org/abs/2210.17323 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.pack_factor = 32 // self.weight_bits + # exllama kernel v1 only supports 4 bit + if self.weight_bits != 4: + raise ValueError( + "Currently, only 4-bit weight quantization is supported for " + f"GPTQ, but got {self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"GPTQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act})") + + @classmethod + def get_name(cls) -> str: + return "gptq" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + return cls(weight_bits, group_size, desc_act) + + def get_linear_method(self) -> "GPTQLinearMethod": + return GPTQLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class ExllamaState(Enum): + + UNUSED = enum.auto() + UNINITIALIZED = enum.auto() + READY = enum.auto() + + +class GPTQLinearMethod(LinearMethodBase): + """Linear method for GPTQ. + + Args: + quant_config: The GPTQ quantization config. 
+ """ + + def __init__(self, quant_config: GPTQConfig): + self.quant_config = quant_config + + def create_weights( + self, + input_size_per_partition: int, + output_size_per_partition: int, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + del output_size # Unused. + if input_size_per_partition % self.quant_config.group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + exllama_state = ExllamaState.UNINITIALIZED + scale_and_zero_size = input_size // group_size + scale_and_zero_input_dim = None + if input_size != input_size_per_partition and self.quant_config.group_size != -1: + # For act-order models, we cannot use Exllama for row parallel layer + if self.quant_config.desc_act: + exllama_state = ExllamaState.UNUSED + else: + # we need to partition qzeros and scales for exllama kernel + scale_and_zero_size = input_size_per_partition // group_size + scale_and_zero_input_dim = 0 + + qweight = Parameter( + torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + device="cuda", + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 0, + "pack_factor": self.quant_config.pack_factor, + }) + g_idx = Parameter( + torch.tensor( + [ + i // self.quant_config.group_size + for i in range(input_size_per_partition) + ], + device="cuda", + dtype=torch.int32, + ), + requires_grad=False, + ) + # Ignore warning from fused linear layers such as QKVParallelLinear. 
+        set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True})
+        qzeros = Parameter(
+            torch.empty(
+                scale_and_zero_size,
+                output_size_per_partition // self.quant_config.pack_factor,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            qzeros, {
+                "input_dim": scale_and_zero_input_dim,
+                "output_dim": 1,
+                "packed_dim": 1,
+                "pack_factor": self.quant_config.pack_factor,
+            })
+        scales = Parameter(
+            torch.empty(
+                scale_and_zero_size,
+                output_size_per_partition,
+                device="cuda",
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(scales, {
+            "input_dim": scale_and_zero_input_dim,
+            "output_dim": 1,
+        })
+        return {
+            "qweight": qweight,
+            "g_idx": g_idx,
+            "qzeros": qzeros,
+            "scales": scales,
+            "exllama_state": exllama_state,
+        }
+
+    def apply_weights(self,
+                      weights: Dict[str, Any],
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        qweight = weights["qweight"]
+        out_shape = x.shape[:-1] + (qweight.shape[-1], )
+        reshaped_x = x.reshape(-1, x.shape[-1])
+        # exllama needs to shuffle the weight after the weight is loaded
+        # here we do the shuffle on first forward pass
+        if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
+            if self.quant_config.desc_act:
+                weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
+                    torch.int)
+            else:
+                weights["g_idx"] = torch.empty((1, 1), device="meta")
+            weights["exllama_state"] = ExllamaState.READY
+            ops.gptq_shuffle(weights["qweight"], weights["g_idx"])
+        output = ops.gptq_gemm(reshaped_x, weights["qweight"],
+                               weights["qzeros"], weights["scales"],
+                               weights["g_idx"],
+                               weights["exllama_state"] == ExllamaState.READY)
+        if bias is not None:
+            output = output + bias
+        return output.reshape(out_shape)
diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py
index f2f9cac6..1932bd14 100644
--- a/vllm/model_executor/layers/quantization/squeezellm.py
+++ b/vllm/model_executor/layers/quantization/squeezellm.py
@@ -67,17 +67,19 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
     def __init__(self, quant_config: SqueezeLLMConfig):
         self.quant_config = quant_config
 
-    def create_weights(self, input_size: int, output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
-        if input_size % self.quant_config.pack_factor != 0:
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
+        if input_size_per_partition % self.quant_config.pack_factor != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
                 "tensor parallel size.")
         qweight = Parameter(
             torch.empty(
-                input_size // self.quant_config.pack_factor,
-                output_size,
+                input_size_per_partition // self.quant_config.pack_factor,
+                output_size_per_partition,
                 device="cuda",
                 dtype=torch.int32,
             ),
@@ -108,7 +110,7 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
         }
 
     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      weights: Dict[str, Any],
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         qweight = weights["qweight"]
diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py
index f8c4d643..90a4edd4 100644
--- a/vllm/model_executor/models/aquila.py
+++ b/vllm/model_executor/models/aquila.py
@@ -332,11 +332,18 @@ class AquilaForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index 4a07c395..3ae346b4 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -355,11 +355,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index 60ec4d9b..a778e552 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -377,6 +377,9 @@ class ChatGLMForCausalLM(nn.Module):
                 continue
             if "word_embeddings" in name:
                 name = name.replace(".word_embeddings", "")
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index 8890d29b..34e71de0 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -425,27 +425,32 @@ class FalconForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
             param = params_dict[name]
             if "query_key_value" in name:
                 output_dim = getattr(param, "output_dim", None)
                 loaded_weight_shape = loaded_weight.shape
-                loaded_weight = loaded_weight.view(
-                    loaded_weight_shape[:output_dim] +
-                    (total_num_kv_heads, num_query_heads_per_kv_head + 2, -1) +
-                    loaded_weight_shape[output_dim + 1:])
-                wq = loaded_weight.narrow(
-                    output_dim + 1, 0, num_query_heads_per_kv_head).reshape(
-                        *loaded_weight_shape[:output_dim], -1,
-                        *loaded_weight_shape[output_dim + 1:])
-                wk = loaded_weight.narrow(
-                    output_dim + 1, num_query_heads_per_kv_head,
-                    1).reshape(*loaded_weight_shape[:output_dim], -1,
-                               *loaded_weight_shape[output_dim + 1:])
-                wv = loaded_weight.narrow(
-                    output_dim + 1, num_query_heads_per_kv_head + 1,
-                    1).reshape(*loaded_weight_shape[:output_dim], -1,
-                               *loaded_weight_shape[output_dim + 1:])
-                loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
+                if output_dim is not None:
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] +
+                        (total_num_kv_heads, num_query_heads_per_kv_head + 2,
+                         -1) + loaded_weight_shape[output_dim + 1:])
+                    wq = loaded_weight.narrow(
+                        output_dim + 1, 0,
+                        num_query_heads_per_kv_head).reshape(
+                            *loaded_weight_shape[:output_dim], -1,
+                            *loaded_weight_shape[output_dim + 1:])
+                    wk = loaded_weight.narrow(
+                        output_dim + 1, num_query_heads_per_kv_head,
+                        1).reshape(*loaded_weight_shape[:output_dim], -1,
+                                   *loaded_weight_shape[output_dim + 1:])
+                    wv = loaded_weight.narrow(
+                        output_dim + 1, num_query_heads_per_kv_head + 1,
+                        1).reshape(*loaded_weight_shape[:output_dim], -1,
+                                   *loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
 
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index 5fe678ec..6ebfdf3d 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -275,7 +275,6 @@ class GPT2LMHeadModel(nn.Module):
                 if not name.endswith(".weight"):
                     continue
                 loaded_weight = loaded_weight.t()
-
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index 1ad344fd..49dfbcd0 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -274,11 +274,18 @@ class GPTJForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index df5c86bf..ac8b2039 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -72,7 +72,6 @@ class GPTNeoXAttention(nn.Module):
             config.hidden_size,
             linear_method=linear_method,
         )
-
         scaling = self.head_size**-0.5
         rotary_dim = int(self.head_size * config.rotary_pct)
         assert rotary_dim % 2 == 0
diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py
index ba28ff8d..4102a08d 100644
--- a/vllm/model_executor/models/internlm.py
+++ b/vllm/model_executor/models/internlm.py
@@ -289,11 +289,18 @@ class InternLMForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index cc83f5dd..d6722808 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -330,11 +330,18 @@ class LlamaForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py
index 6e67d8b8..576d9674 100644
--- a/vllm/model_executor/models/mistral.py
+++ b/vllm/model_executor/models/mistral.py
@@ -321,11 +321,18 @@ class MistralForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index b11e3713..a3f1582a 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -153,7 +153,7 @@ class MixtralMoE(nn.Module):
         self.gate = ReplicatedLinear(config.hidden_size,
                                      self.num_total_experts,
                                      bias=False,
-                                     linear_method=linear_method)
+                                     linear_method=None)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -418,11 +418,18 @@ class MixtralForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index 22b2a5c6..dd4fa547 100644
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -297,6 +297,9 @@ class MPTForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 1c698c20..3f9b3385 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -345,11 +345,18 @@ class OPTForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
diff --git a/vllm/model_executor/models/phi_1_5.py b/vllm/model_executor/models/phi_1_5.py
index ac441e47..0c9671e6 100644
--- a/vllm/model_executor/models/phi_1_5.py
+++ b/vllm/model_executor/models/phi_1_5.py
@@ -305,6 +305,9 @@ class PhiForCausalLM(nn.Module):
 
             if "rotary_emb.inv_freq" in name:
                 continue
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
             # pylint: disable=E1136
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index 33bae61f..f41a3cc4 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -82,7 +82,6 @@ class QWenAttention(nn.Module):
         self.num_heads = (self.total_num_heads //
                           tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
-
         self.c_attn = QKVParallelLinear(
             hidden_size,
             self.head_dim,
@@ -279,11 +278,18 @@ class QWenLMHeadModel(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
diff --git a/vllm/model_executor/models/yi.py b/vllm/model_executor/models/yi.py
index 889cc3f0..d16f21c0 100644
--- a/vllm/model_executor/models/yi.py
+++ b/vllm/model_executor/models/yi.py
@@ -320,11 +320,18 @@ class YiForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py
index 825b9016..36ad0f38 100644
--- a/vllm/model_executor/weight_utils.py
+++ b/vllm/model_executor/weight_utils.py
@@ -287,4 +287,5 @@ def initialize_dummy_weights(
     values between -1e-3 and 1e-3 works well for most models.
     """
     for param in model.state_dict().values():
-        param.data.uniform_(low, high)
+        if torch.is_floating_point(param):
+            param.data.uniform_(low, high)
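Note (not part of the patch): a minimal usage sketch of the GPTQ path added above, assuming the extension is built with the new kernels and a GPTQ-quantized checkpoint is available; the model id below is a placeholder, and sampling settings are illustrative only. The lazy exllama shuffle in GPTQLinearMethod.apply_weights runs once per linear layer on the first forward pass and is reused afterwards.

from vllm import LLM, SamplingParams

# Placeholder model id: any GPTQ-quantized checkpoint with its quantization config.
llm = LLM(model="TheBloke/Llama-2-7B-GPTQ", quantization="gptq")

# The first generate() call moves each linear layer from ExllamaState.UNINITIALIZED
# to READY via ops.gptq_shuffle(); later calls reuse the shuffled weights.
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
print(outputs[0].outputs[0].text)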