[Misc] Remove SqueezeLLM
(#8220)
This commit is contained in:
parent
9db52eab3d
commit
23f322297f
@ -181,7 +181,6 @@ set(VLLM_EXT_SRC
|
|||||||
"csrc/pos_encoding_kernels.cu"
|
"csrc/pos_encoding_kernels.cu"
|
||||||
"csrc/activation_kernels.cu"
|
"csrc/activation_kernels.cu"
|
||||||
"csrc/layernorm_kernels.cu"
|
"csrc/layernorm_kernels.cu"
|
||||||
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
|
|
||||||
"csrc/quantization/gptq/q_gemm.cu"
|
"csrc/quantization/gptq/q_gemm.cu"
|
||||||
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
|
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
|
||||||
"csrc/quantization/fp8/common.cu"
|
"csrc/quantization/fp8/common.cu"
|
||||||
|
@ -170,9 +170,6 @@ void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
|||||||
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||||
torch::Tensor& scales);
|
torch::Tensor& scales);
|
||||||
|
|
||||||
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
|
||||||
torch::Tensor lookup_table);
|
|
||||||
|
|
||||||
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
||||||
torch::Tensor b_gptq_qzeros,
|
torch::Tensor b_gptq_qzeros,
|
||||||
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
|
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
|
||||||
|
@ -1,216 +0,0 @@
|
|||||||
#include <torch/all.h>
|
|
||||||
#include <cuda.h>
|
|
||||||
#include <cuda_runtime.h>
|
|
||||||
#include <cuda_fp16.h>
|
|
||||||
|
|
||||||
// half-tensor
|
|
||||||
#include <c10/cuda/CUDAStream.h>
|
|
||||||
#include <ATen/cuda/CUDATensorMethods.cuh>
|
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
|
|
||||||
#define BLOCKWIDTH 128
|
|
||||||
#define BLOCKHEIGHT4 16
|
|
||||||
|
|
||||||
namespace vllm {
|
|
||||||
namespace squeezellm {
|
|
||||||
|
|
||||||
__device__ inline unsigned int as_unsigned(int i) {
|
|
||||||
return *reinterpret_cast<unsigned int*>(&i);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4-bit matvec kernel (LUT-based)
|
|
||||||
__global__ void NUQ4MatMulKernel(
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
const half2* __restrict__ vec,
|
|
||||||
#else
|
|
||||||
const __half2* __restrict__ vec,
|
|
||||||
#endif
|
|
||||||
const int* __restrict__ mat,
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
half2* __restrict__ mul,
|
|
||||||
#else
|
|
||||||
float2* __restrict__ mul,
|
|
||||||
#endif
|
|
||||||
const __half* __restrict__ lookup_table, int height, int width, int batch,
|
|
||||||
int vec_height) {
|
|
||||||
|
|
||||||
const int blockwidth2 = BLOCKWIDTH / 2;
|
|
||||||
|
|
||||||
int row = BLOCKHEIGHT4 * blockIdx.x;
|
|
||||||
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
__shared__ half2 blockvec[blockwidth2];
|
|
||||||
#else
|
|
||||||
__shared__ __half2 blockvec[blockwidth2];
|
|
||||||
#endif
|
|
||||||
|
|
||||||
__shared__ __half deq2[16][BLOCKWIDTH];
|
|
||||||
int off = threadIdx.x;
|
|
||||||
int column_offset = col * 16;
|
|
||||||
for (int val = 0; val < 16; val += 1) {
|
|
||||||
int lut_index = column_offset + val;
|
|
||||||
deq2[val][off] = lookup_table[lut_index];
|
|
||||||
}
|
|
||||||
|
|
||||||
__half res;
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
half2 res2;
|
|
||||||
half2 tmp2;
|
|
||||||
#else
|
|
||||||
__half2 res2;
|
|
||||||
__half2 tmp2;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
int i;
|
|
||||||
int k;
|
|
||||||
|
|
||||||
unsigned int tmp1;
|
|
||||||
unsigned int lut_index1, lut_index2;
|
|
||||||
|
|
||||||
for (int b = 0; b < batch; ++b) {
|
|
||||||
i = width * row + col;
|
|
||||||
res = __int2half_rd(0);
|
|
||||||
k = 0;
|
|
||||||
|
|
||||||
__syncthreads();
|
|
||||||
if (threadIdx.x < blockwidth2)
|
|
||||||
blockvec[threadIdx.x] =
|
|
||||||
vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 +
|
|
||||||
threadIdx.x];
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
while (k < blockwidth2) {
|
|
||||||
tmp1 = as_unsigned(mat[i]);
|
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
res2 = {};
|
|
||||||
tmp2 = {};
|
|
||||||
#else
|
|
||||||
res2.x = __half_as_ushort(__float2half(0));
|
|
||||||
res2.y = __half_as_ushort(__float2half(0));
|
|
||||||
tmp2.x = __half_as_ushort(__float2half(0));
|
|
||||||
tmp2.y = __half_as_ushort(__float2half(0));
|
|
||||||
#endif
|
|
||||||
|
|
||||||
lut_index1 = tmp1 & 0xF;
|
|
||||||
lut_index2 = (tmp1 >> 4) & 0xF;
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
tmp2.x = deq2[lut_index1][off];
|
|
||||||
tmp2.y = deq2[lut_index2][off];
|
|
||||||
#else
|
|
||||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
|
||||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
|
||||||
#endif
|
|
||||||
res2 = __hfma2(tmp2, blockvec[k + 0], res2);
|
|
||||||
|
|
||||||
lut_index1 = (tmp1 >> 8) & 0xF;
|
|
||||||
lut_index2 = (tmp1 >> 12) & 0xF;
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
tmp2.x = deq2[lut_index1][off];
|
|
||||||
tmp2.y = deq2[lut_index2][off];
|
|
||||||
#else
|
|
||||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
|
||||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
|
||||||
#endif
|
|
||||||
res2 = __hfma2(tmp2, blockvec[k + 1], res2);
|
|
||||||
|
|
||||||
lut_index1 = (tmp1 >> 16) & 0xF;
|
|
||||||
lut_index2 = (tmp1 >> 20) & 0xF;
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
tmp2.x = deq2[lut_index1][off];
|
|
||||||
tmp2.y = deq2[lut_index2][off];
|
|
||||||
#else
|
|
||||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
|
||||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
|
||||||
#endif
|
|
||||||
res2 = __hfma2(tmp2, blockvec[k + 2], res2);
|
|
||||||
|
|
||||||
lut_index1 = (tmp1 >> 24) & 0xF;
|
|
||||||
lut_index2 = (tmp1 >> 28) & 0xF;
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
tmp2.x = deq2[lut_index1][off];
|
|
||||||
tmp2.y = deq2[lut_index2][off];
|
|
||||||
#else
|
|
||||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
|
||||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
|
||||||
#endif
|
|
||||||
res2 = __hfma2(tmp2, blockvec[k + 3], res2);
|
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
res = __hadd(__hadd(res2.x, res2.y), res);
|
|
||||||
#else
|
|
||||||
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)),
|
|
||||||
res);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
i += width;
|
|
||||||
k += 4;
|
|
||||||
}
|
|
||||||
|
|
||||||
// col%2 -> only set one of the two values
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
half2 res3 = {};
|
|
||||||
if (col % 2 == 0) {
|
|
||||||
res3.x = res;
|
|
||||||
} else {
|
|
||||||
res3.y = res;
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
__half2 res3;
|
|
||||||
res3.x = __half_as_ushort(__float2half(0));
|
|
||||||
res3.y = __half_as_ushort(__float2half(0));
|
|
||||||
if (col % 2 == 0) {
|
|
||||||
res3.x = __half_as_ushort(res);
|
|
||||||
} else {
|
|
||||||
res3.y = __half_as_ushort(res);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
atomicAdd(&mul[b * width / 2 + col / 2], res3);
|
|
||||||
#else
|
|
||||||
int tmp_addr = b * width / 2 + col / 2;
|
|
||||||
atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
|
|
||||||
atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace squeezellm
|
|
||||||
} // namespace vllm
|
|
||||||
|
|
||||||
// 4-bit matvec kernel (LUT-based)
|
|
||||||
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
|
||||||
torch::Tensor lookup_table) {
|
|
||||||
int height = mat.size(0);
|
|
||||||
int width = mat.size(1);
|
|
||||||
|
|
||||||
int batch = vec.size(0);
|
|
||||||
int vec_height = vec.size(1);
|
|
||||||
|
|
||||||
dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
|
|
||||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH);
|
|
||||||
dim3 threads(BLOCKWIDTH);
|
|
||||||
|
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
|
||||||
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
(half2*)vec.data_ptr<at::Half>(),
|
|
||||||
#else
|
|
||||||
(__half2*)vec.data_ptr<at::Half>(),
|
|
||||||
#endif
|
|
||||||
mat.data_ptr<int>(),
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
(half2*)mul.data_ptr<at::Half>(),
|
|
||||||
(__half*)lookup_table.data_ptr<at::Half>(),
|
|
||||||
#else
|
|
||||||
(float2*)mul.data_ptr<float>(),
|
|
||||||
(__half*)lookup_table.data_ptr<at::Half>(),
|
|
||||||
#endif
|
|
||||||
height, width, batch, vec_height);
|
|
||||||
}
|
|
||||||
|
|
||||||
#undef BLOCKWIDTH
|
|
||||||
#undef BLOCKHEIGHT4
|
|
@ -237,12 +237,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
|
ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
|
||||||
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
|
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
|
||||||
|
|
||||||
// Quantized GEMM for SqueezeLLM.
|
|
||||||
ops.def(
|
|
||||||
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
|
|
||||||
"lookup_table) -> ()");
|
|
||||||
ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm);
|
|
||||||
|
|
||||||
// Compute FP8 quantized tensor for given scaling factor.
|
// Compute FP8 quantized tensor for given scaling factor.
|
||||||
ops.def(
|
ops.def(
|
||||||
"static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
|
"static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
|
||||||
|
@ -119,17 +119,6 @@ The table below shows the compatibility of various quantization implementations
|
|||||||
- ✗
|
- ✗
|
||||||
- ✗
|
- ✗
|
||||||
- ✗
|
- ✗
|
||||||
* - SqueezeLLM
|
|
||||||
- ✅︎
|
|
||||||
- ✅︎
|
|
||||||
- ✅︎
|
|
||||||
- ✅︎
|
|
||||||
- ✅︎
|
|
||||||
- ✗
|
|
||||||
- ✗
|
|
||||||
- ✗
|
|
||||||
- ✗
|
|
||||||
- ✗
|
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
^^^^^^
|
^^^^^^
|
||||||
|
@ -62,7 +62,7 @@ This script evaluates the inference throughput of language models using various
|
|||||||
|
|
||||||
python3 benchmarks/benchmark_throughput.py --help
|
python3 benchmarks/benchmark_throughput.py --help
|
||||||
usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL]
|
usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL]
|
||||||
[--tokenizer TOKENIZER] [--quantization {awq,gptq,squeezellm,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N]
|
[--tokenizer TOKENIZER] [--quantization {awq,gptq,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N]
|
||||||
[--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code]
|
[--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code]
|
||||||
[--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}]
|
[--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}]
|
||||||
[--quantization-param-path KV_CACHE_quantization_param_path]
|
[--quantization-param-path KV_CACHE_quantization_param_path]
|
||||||
@ -76,7 +76,7 @@ optional arguments:
|
|||||||
--output-len OUTPUT_LEN Output length for each request. Overrides the output length from the dataset.
|
--output-len OUTPUT_LEN Output length for each request. Overrides the output length from the dataset.
|
||||||
--model MODEL
|
--model MODEL
|
||||||
--tokenizer TOKENIZER
|
--tokenizer TOKENIZER
|
||||||
--quantization {awq,gptq,squeezellm,None}, -q {awq,gptq,squeezellm,None}
|
--quantization {awq,gptq,None}, -q {awq,gptq,None}
|
||||||
--tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE
|
--tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE
|
||||||
--n N Number of generated sequences per prompt.
|
--n N Number of generated sequences per prompt.
|
||||||
--use-beam-search
|
--use-beam-search
|
||||||
|
@ -209,12 +209,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
|
|||||||
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
|
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
|
||||||
|
|
||||||
|
|
||||||
# squeezellm
|
|
||||||
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
|
|
||||||
lookup_table: torch.Tensor) -> None:
|
|
||||||
torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table)
|
|
||||||
|
|
||||||
|
|
||||||
# marlin
|
# marlin
|
||||||
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||||
b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
|
b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
|
||||||
|
@ -277,7 +277,7 @@ class ModelConfig:
|
|||||||
|
|
||||||
def _verify_quantization(self) -> None:
|
def _verify_quantization(self) -> None:
|
||||||
supported_quantization = [*QUANTIZATION_METHODS]
|
supported_quantization = [*QUANTIZATION_METHODS]
|
||||||
rocm_supported_quantization = ["awq", "gptq", "squeezellm", "fp8"]
|
rocm_supported_quantization = ["awq", "gptq", "fp8"]
|
||||||
optimized_quantization_methods = [
|
optimized_quantization_methods = [
|
||||||
"fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
|
"fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
|
||||||
"fbgemm_fp8", "compressed_tensors", "compressed-tensors",
|
"fbgemm_fp8", "compressed_tensors", "compressed-tensors",
|
||||||
@ -1537,7 +1537,7 @@ class LoRAConfig:
|
|||||||
if model_config.quantization and model_config.quantization not in [
|
if model_config.quantization and model_config.quantization not in [
|
||||||
"awq", "gptq"
|
"awq", "gptq"
|
||||||
]:
|
]:
|
||||||
# TODO support marlin and squeezellm
|
# TODO support marlin
|
||||||
logger.warning("%s quantization is not tested with LoRA yet.",
|
logger.warning("%s quantization is not tested with LoRA yet.",
|
||||||
model_config.quantization)
|
model_config.quantization)
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ class LLM:
|
|||||||
However, if the `torch_dtype` in the config is `float32`, we will
|
However, if the `torch_dtype` in the config is `float32`, we will
|
||||||
use `float16` instead.
|
use `float16` instead.
|
||||||
quantization: The method used to quantize the model weights. Currently,
|
quantization: The method used to quantize the model weights. Currently,
|
||||||
we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
|
we support "awq", "gptq", and "fp8" (experimental).
|
||||||
If None, we first check the `quantization_config` attribute in the
|
If None, we first check the `quantization_config` attribute in the
|
||||||
model config file. If that is None, we assume the model weights are
|
model config file. If that is None, we assume the model weights are
|
||||||
not quantized and use `dtype` to determine the data type of
|
not quantized and use `dtype` to determine the data type of
|
||||||
|
@ -39,7 +39,7 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
|
|||||||
# unquantizedLinear
|
# unquantizedLinear
|
||||||
if hasattr(base_layer, "weight"):
|
if hasattr(base_layer, "weight"):
|
||||||
return base_layer.weight.device
|
return base_layer.weight.device
|
||||||
# GPTQ/AWQ/SqueezeLLM
|
# GPTQ/AWQ
|
||||||
elif hasattr(base_layer, "qweight"):
|
elif hasattr(base_layer, "qweight"):
|
||||||
return base_layer.qweight.device
|
return base_layer.qweight.device
|
||||||
# marlin
|
# marlin
|
||||||
|
@ -25,7 +25,6 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
|||||||
from vllm.model_executor.layers.quantization.neuron_quant import (
|
from vllm.model_executor.layers.quantization.neuron_quant import (
|
||||||
NeuronQuantConfig)
|
NeuronQuantConfig)
|
||||||
from vllm.model_executor.layers.quantization.qqq import QQQConfig
|
from vllm.model_executor.layers.quantization.qqq import QQQConfig
|
||||||
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
|
|
||||||
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
|
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
|
||||||
|
|
||||||
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||||
@ -43,7 +42,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
|||||||
"gptq_marlin": GPTQMarlinConfig,
|
"gptq_marlin": GPTQMarlinConfig,
|
||||||
"awq_marlin": AWQMarlinConfig,
|
"awq_marlin": AWQMarlinConfig,
|
||||||
"gptq": GPTQConfig,
|
"gptq": GPTQConfig,
|
||||||
"squeezellm": SqueezeLLMConfig,
|
|
||||||
"compressed-tensors": CompressedTensorsConfig,
|
"compressed-tensors": CompressedTensorsConfig,
|
||||||
"bitsandbytes": BitsAndBytesConfig,
|
"bitsandbytes": BitsAndBytesConfig,
|
||||||
"qqq": QQQConfig,
|
"qqq": QQQConfig,
|
||||||
|
@ -1,138 +0,0 @@
|
|||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.nn.parameter import Parameter
|
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
from vllm.model_executor.layers.linear import LinearBase
|
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
|
||||||
QuantizationConfig, QuantizeMethodBase)
|
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
|
||||||
from vllm.utils import is_hip
|
|
||||||
|
|
||||||
|
|
||||||
class SqueezeLLMConfig(QuantizationConfig):
|
|
||||||
"""Config class for SqueezeLLM.
|
|
||||||
|
|
||||||
Reference: https://arxiv.org/pdf/2306.07629
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
weight_bits: int,
|
|
||||||
) -> None:
|
|
||||||
self.weight_bits = weight_bits
|
|
||||||
|
|
||||||
if self.weight_bits != 4:
|
|
||||||
raise ValueError(
|
|
||||||
"Currently, only 4-bit weight quantization is supported for "
|
|
||||||
f"SqueezeLLM, but got {self.weight_bits} bits.")
|
|
||||||
|
|
||||||
self.pack_factor = 32 // self.weight_bits
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"SqueezeLLMConfig(weight_bits={self.weight_bits})"
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "squeezellm"
|
|
||||||
|
|
||||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
|
||||||
return [torch.half]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_min_capability(cls) -> int:
|
|
||||||
return 70
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_config_filenames() -> List[str]:
|
|
||||||
return ["quant_config.json"]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
|
|
||||||
weight_bits = cls.get_from_keys(config, ["wbits"])
|
|
||||||
return cls(weight_bits)
|
|
||||||
|
|
||||||
def get_quant_method(self, layer: torch.nn.Module,
|
|
||||||
prefix: str) -> Optional[QuantizeMethodBase]:
|
|
||||||
if isinstance(layer, LinearBase):
|
|
||||||
return SqueezeLLMLinearMethod(self)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_scaled_act_names(self) -> List[str]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
class SqueezeLLMLinearMethod(QuantizeMethodBase):
|
|
||||||
"""Linear method for SqueezeLLM.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
quant_config: The SqueezeLLM quantization config.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, quant_config: SqueezeLLMConfig):
|
|
||||||
self.quant_config = quant_config
|
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module,
|
|
||||||
input_size_per_partition: int,
|
|
||||||
output_partition_sizes: List[int], input_size: int,
|
|
||||||
output_size: int, params_dtype: torch.dtype,
|
|
||||||
**extra_weight_attrs):
|
|
||||||
if input_size_per_partition % self.quant_config.pack_factor != 0:
|
|
||||||
raise ValueError(
|
|
||||||
"The input size is not aligned with the quantized "
|
|
||||||
"weight shape. This can be caused by too large "
|
|
||||||
"tensor parallel size.")
|
|
||||||
|
|
||||||
output_size_per_partition = sum(output_partition_sizes)
|
|
||||||
qweight = Parameter(
|
|
||||||
torch.empty(
|
|
||||||
input_size_per_partition // self.quant_config.pack_factor,
|
|
||||||
output_size_per_partition,
|
|
||||||
dtype=torch.int32,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
set_weight_attrs(
|
|
||||||
qweight, {
|
|
||||||
"input_dim": 0,
|
|
||||||
"output_dim": 1,
|
|
||||||
"packed_dim": 0,
|
|
||||||
"pack_factor": self.quant_config.pack_factor,
|
|
||||||
})
|
|
||||||
lookup_table = Parameter(
|
|
||||||
torch.empty(
|
|
||||||
output_size,
|
|
||||||
self.quant_config.weight_bits**2,
|
|
||||||
dtype=params_dtype,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
set_weight_attrs(lookup_table, {
|
|
||||||
"output_dim": 0,
|
|
||||||
})
|
|
||||||
|
|
||||||
layer.register_parameter("qweight", qweight)
|
|
||||||
set_weight_attrs(qweight, extra_weight_attrs)
|
|
||||||
layer.register_parameter("lookup_table", lookup_table)
|
|
||||||
set_weight_attrs(lookup_table, extra_weight_attrs)
|
|
||||||
|
|
||||||
def apply(self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
||||||
qweight = layer.qweight
|
|
||||||
lookup_table = layer.lookup_table
|
|
||||||
out_shape = x.shape[:-1] + (qweight.shape[-1], )
|
|
||||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
|
||||||
if is_hip():
|
|
||||||
out_f = torch.zeros(out_shape, dtype=torch.float)
|
|
||||||
ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
|
|
||||||
out = out_f.to(dtype=torch.float16)
|
|
||||||
else:
|
|
||||||
# NOTE: The output tensor should be zero-initialized.
|
|
||||||
out = torch.zeros(out_shape, dtype=torch.float16)
|
|
||||||
ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
|
|
||||||
|
|
||||||
if bias is not None:
|
|
||||||
out.add_(bias)
|
|
||||||
return out.reshape(out_shape)
|
|
Loading…
x
Reference in New Issue
Block a user