Support SqueezeLLM (#1326)

Co-authored-by: squeeze-ai-lab <squeezeailab.bair@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

parent bf31d3606a
commit 1f24755bf8
@@ -70,7 +70,7 @@ if __name__ == '__main__':
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', None],
+                        choices=['awq', 'squeezellm', None],
                         default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--input-len', type=int, default=32)
@@ -201,7 +201,7 @@ if __name__ == "__main__":
     parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', None],
+                        choices=['awq', 'squeezellm', None],
                         default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n",
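Both benchmark scripts gain the new choice. As a quick sanity check of the extended flag, a minimal sketch that mirrors the argparse setup above, outside the full scripts:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--quantization',
                    '-q',
                    choices=['awq', 'squeezellm', None],
                    default=None)

args = parser.parse_args(['-q', 'squeezellm'])
assert args.quantization == 'squeezellm'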
@@ -7,9 +7,13 @@ torch::Tensor awq_gemm(
   torch::Tensor _zeros,
   int split_k_iters);
 
+void squeezellm_gemm(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor lookup_table);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def(
-    "awq_gemm",
-    &awq_gemm,
-    "Quantized GEMM for AWQ");
+  m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
+  m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
 }
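The new binding exposes squeezellm_gemm to Python. A hedged sketch of the calling convention, with shapes inferred from the kernel and the Python-side layer added later in this commit (a CUDA device and a built extension are assumed; the sizes are example values):

import torch
from vllm import quantization_ops

batch, in_features, out_features = 4, 4096, 4096
pack_factor = 8  # eight 4-bit values per int32
vec = torch.randn(batch, in_features, dtype=torch.float16, device="cuda")
mat = torch.randint(-2**31, 2**31 - 1,
                    (in_features // pack_factor, out_features),
                    dtype=torch.int32, device="cuda")
lut = torch.randn(out_features, 16, dtype=torch.float16, device="cuda")
# The kernel accumulates with atomicAdd, so out must be zero-initialized.
out = torch.zeros(batch, out_features, dtype=torch.float16, device="cuda")
quantization_ops.squeezellm_gemm(vec, mat, out, lut)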
csrc/quantization/squeezellm/quant_cuda_kernel.cu (new file, 148 lines)
@@ -0,0 +1,148 @@
+#include <torch/all.h>
+#include <torch/python.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+// half-tensor
+#include <c10/cuda/CUDAStream.h>
+#include <ATen/cuda/CUDATensorMethods.cuh>
+
+#define BLOCKWIDTH 128
+#define BLOCKHEIGHT4 16
+
+namespace vllm {
+namespace squeezellm {
+
+__device__ inline unsigned int as_unsigned(int i) {
+  return *reinterpret_cast<unsigned int*>(&i);
+}
+
+// 4-bit matvec kernel (LUT-based)
+__global__ void NUQ4MatMulKernel(
+    const half2* __restrict__ vec,
+    const int* __restrict__ mat,
+    half2* __restrict__ mul,
+    const __half* __restrict__ lookup_table,
+    int height,
+    int width,
+    int batch,
+    int vec_height
+) {
+
+  const int blockwidth2 = BLOCKWIDTH / 2;
+
+  int row = BLOCKHEIGHT4 * blockIdx.x;
+  int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+
+  __shared__ half2 blockvec[blockwidth2];
+
+  __shared__ __half deq2[16][BLOCKWIDTH];
+  int off = threadIdx.x;
+  int column_offset = col * 16;
+  for (int val = 0; val < 16; val += 1) {
+    int lut_index = column_offset + val;
+    deq2[val][off] = lookup_table[lut_index];
+  }
+
+  __half res;
+  half2 res2;
+  half2 tmp2;
+
+  int i;
+  int k;
+
+  unsigned int tmp1;
+  unsigned int lut_index1, lut_index2;
+
+  for (int b = 0; b < batch; ++b){
+    i = width * row + col;
+    res = __int2half_rd(0);
+    k = 0;
+
+    __syncthreads();
+    if (threadIdx.x < blockwidth2)
+      blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
+    __syncthreads();
+
+    while (k < blockwidth2) {
+      tmp1 = as_unsigned(mat[i]);
+
+      res2 = {};
+      tmp2 = {};
+
+      lut_index1 = tmp1 & 0xF;
+      lut_index2 = (tmp1 >> 4) & 0xF;
+      tmp2.x = deq2[lut_index1][off];
+      tmp2.y = deq2[lut_index2][off];
+      res2 = __hfma2(tmp2, blockvec[k + 0], res2);
+
+      lut_index1 = (tmp1 >> 8) & 0xF;
+      lut_index2 = (tmp1 >> 12) & 0xF;
+      tmp2.x = deq2[lut_index1][off];
+      tmp2.y = deq2[lut_index2][off];
+      res2 = __hfma2(tmp2, blockvec[k + 1], res2);
+
+      lut_index1 = (tmp1 >> 16) & 0xF;
+      lut_index2 = (tmp1 >> 20) & 0xF;
+      tmp2.x = deq2[lut_index1][off];
+      tmp2.y = deq2[lut_index2][off];
+      res2 = __hfma2(tmp2, blockvec[k + 2], res2);
+
+      lut_index1 = (tmp1 >> 24) & 0xF;
+      lut_index2 = (tmp1 >> 28) & 0xF;
+      tmp2.x = deq2[lut_index1][off];
+      tmp2.y = deq2[lut_index2][off];
+      res2 = __hfma2(tmp2, blockvec[k + 3], res2);
+
+      res = __hadd(__hadd(res2.x, res2.y), res);
+
+      i += width;
+      k += 4;
+    }
+
+    // col%2 -> only set one of the two values
+    half2 res3 = {};
+    if (col % 2 == 0) {
+      res3.x = res;
+    } else {
+      res3.y = res;
+    }
+
+    atomicAdd(&mul[b * width / 2 + col / 2], res3);
+  }
+}
+
+} // namespace squeezellm
+} // namespace vllm
+
+// 4-bit matvec kernel (LUT-based)
+void squeezellm_gemm(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor lookup_table
+) {
+  int height = mat.size(0);
+  int width = mat.size(1);
+
+  int batch = vec.size(0);
+  int vec_height = vec.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
+    (half2*) vec.data<at::Half>(),
+    mat.data_ptr<int>(),
+    (half2*) mul.data<at::Half>(),
+    (__half*) lookup_table.data<at::Half>(),
+    height, width, batch, vec_height
+  );
+}
+
+#undef BLOCKWIDTH
+#undef BLOCKHEIGHT4
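For testing the packing layout, here is a pure-PyTorch reference of what NUQ4MatMulKernel computes, a sketch under the layout the kernel implies: each int32 in mat packs eight 4-bit LUT indices along the input dimension, low nibble first, and lookup_table[j] holds the 16 centroids for output column j.

import torch

def squeezellm_matmul_ref(vec: torch.Tensor, mat: torch.Tensor,
                          lookup_table: torch.Tensor) -> torch.Tensor:
    # vec: [batch, in_features] fp16; mat: [in_features // 8, out_features] int32;
    # lookup_table: [out_features, 16] fp16.
    packed = mat.to(torch.int64) & 0xFFFFFFFF           # reinterpret bits as unsigned
    shifts = torch.arange(0, 32, 4, device=mat.device)  # nibble offsets 0, 4, ..., 28
    idx = (packed.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF
    idx = idx.reshape(-1, mat.shape[1])                 # [in_features, out_features]
    # weight[i, j] = lookup_table[j, idx[i, j]]
    lut_t = lookup_table.t().contiguous().to(torch.float32)
    weight = torch.gather(lut_t, 0, idx)
    return (vec.to(torch.float32) @ weight).to(vec.dtype)

This should agree with the kernel output up to fp16 accumulation differences, which makes it a convenient oracle for a unit test.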
setup.py (1 line changed)
@@ -200,6 +200,7 @@ quantization_extension = CUDAExtension(
     sources=[
         "csrc/quantization.cpp",
         "csrc/quantization/awq/gemm_kernels.cu",
+        "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
     ],
     extra_compile_args={
         "cxx": CXX_FLAGS,
@@ -103,7 +103,7 @@ class ModelConfig:
         self.tokenizer_mode = tokenizer_mode
 
     def _verify_quantization(self) -> None:
-        supported_quantization = ["awq"]
+        supported_quantization = ["awq", "squeezellm"]
         if self.quantization is None:
             return
         quantization = self.quantization.lower()
@@ -168,7 +168,7 @@ class EngineArgs:
         parser.add_argument('--quantization',
                             '-q',
                             type=str,
-                            choices=['awq', None],
+                            choices=['awq', 'squeezellm', None],
                             default=None,
                             help='Method used to quantize the weights')
         return parser
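Once the flag reaches the engine, a SqueezeLLM checkpoint can be served end to end. A hedged usage sketch (the checkpoint path is a placeholder for a directory containing SqueezeLLM weights plus a quant_config.json):

from vllm import LLM

llm = LLM(model="path/to/squeezellm-checkpoint", quantization="squeezellm")
outputs = llm.generate("Hello, my name is")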
@@ -1,10 +1,14 @@
 from vllm.model_executor.layers.quantized_linear.awq import (
     AWQColumnParallelLinear, AWQRowParallelLinear)
+from vllm.model_executor.layers.quantized_linear.squeezellm import (
+    SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear)
 from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
                                                        RowParallelLinear)
 
 _QUANTIZED_LINEAR_REGISTRY = {
     "awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
+    "squeezellm":
+    (SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear),
 }
 
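A sketch of how a caller would resolve the layer classes from this registry (the helper name is illustrative, not part of this commit):

from vllm.model_executor.layers.quantized_linear import (
    _QUANTIZED_LINEAR_REGISTRY)

def get_quantized_linear_classes(quant_method: str):
    if quant_method not in _QUANTIZED_LINEAR_REGISTRY:
        raise ValueError(f"Unsupported quantization method: {quant_method}")
    return _QUANTIZED_LINEAR_REGISTRY[quant_method]

column_cls, row_cls = get_quantized_linear_classes("squeezellm")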
@@ -11,9 +11,11 @@ from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
 class AWQColumnParallelLinear(ColumnParallelLinear):
 
     def create_weights(self, dtype: torch.dtype) -> None:
-        assert self.input_size % self.quant_config.weight_bits == 0
-        assert (self.output_size_per_partition %
-                self.quant_config.pack_factor == 0)
+        assert self.input_size % self.quant_config.group_size == 0
+        if self.output_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                "The tensor parallel size is not aligned with the quantized "
+                "weight shape. Please use a different tensor parallel size.")
         self.qweight = Parameter(
             torch.empty(
                 self.input_size,
@@ -62,9 +64,11 @@ class AWQColumnParallelLinear(ColumnParallelLinear):
 class AWQRowParallelLinear(RowParallelLinear):
 
     def create_weights(self, dtype: torch.dtype) -> None:
-        assert (self.input_size_per_partition %
-                self.quant_config.weight_bits == 0)
         assert self.output_size % self.quant_config.pack_factor == 0
+        if self.input_size_per_partition % self.quant_config.group_size != 0:
+            raise ValueError(
+                "The tensor parallel size is not aligned with the quantized "
+                "weight shape. Please use a different tensor parallel size.")
         self.qweight = Parameter(
             torch.empty(
                 self.input_size_per_partition,
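The asserts become actionable errors: a shard is valid only if the per-partition dimension stays divisible by group_size and pack_factor. A numeric illustration with example sizes (LLaMA-7B uses an intermediate size of 11008):

pack_factor = 32 // 4   # eight 4-bit values per int32
group_size = 128
output_size, input_size = 11008, 4096
for tp_size in (1, 2, 4):
    assert (output_size // tp_size) % pack_factor == 0   # column-parallel check
    assert (input_size // tp_size) % group_size == 0     # row-parallel check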
vllm/model_executor/layers/quantized_linear/squeezellm.py (new file, 84 lines)
@@ -0,0 +1,84 @@
+from typing import Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm import quantization_ops
+from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
+                                                       RowParallelLinear)
+
+
+class SqueezeLLMColumnParallelLinear(ColumnParallelLinear):
+
+    def create_weights(self, dtype: torch.dtype) -> None:
+        assert self.input_size % self.quant_config.pack_factor == 0
+        self.qweight = Parameter(
+            torch.empty(
+                self.input_size // self.quant_config.pack_factor,
+                self.output_size_per_partition,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        self.lookup_table = Parameter(
+            torch.empty(
+                self.output_size_per_partition,
+                self.quant_config.weight_bits**2,
+                device="cuda",
+                dtype=dtype,
+            ),
+            requires_grad=False,
+        )
+
+    def apply_weights(
+        self,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
+        reshaped_x = x.reshape(-1, x.shape[-1])
+        # NOTE: The output tensor should be zero-initialized.
+        out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
+        quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
+                                         self.lookup_table)
+
+        if bias is not None:
+            out = out + bias
+        return out.reshape(out_shape)
+
+
+class SqueezeLLMRowParallelLinear(RowParallelLinear):
+
+    def create_weights(self, dtype: torch.dtype) -> None:
+        if self.input_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                "The tensor parallel size is not aligned with the quantized "
+                "weight shape. Please use a different tensor parallel size.")
+        self.qweight = Parameter(
+            torch.empty(
+                self.input_size_per_partition // self.quant_config.pack_factor,
+                self.output_size,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        self.lookup_table = Parameter(
+            torch.empty(
+                self.output_size,
+                self.quant_config.weight_bits**2,
+                device="cuda",
+                dtype=dtype,
+            ),
+            requires_grad=False,
+        )
+
+    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
+        reshaped_x = x.reshape(-1, x.shape[-1])
+        out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
+        # NOTE: The output tensor should be zero-initialized.
+        out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
+        quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
+                                         self.lookup_table)
+        return out.reshape(out_shape)
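Shape bookkeeping implied by create_weights above, with example sizes: 4-bit weights give pack_factor = 32 // 4 = 8, and the lookup table holds weight_bits ** 2 = 16 centroids per output channel.

input_size, output_size = 4096, 4096
pack_factor = 32 // 4
qweight_shape = (input_size // pack_factor, output_size)  # (512, 4096), int32
lookup_table_shape = (output_size, 4 ** 2)                # (4096, 16), fp16
assert qweight_shape == (512, 4096) and lookup_table_shape == (4096, 16)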
@@ -314,17 +314,21 @@ class LlamaForCausalLM(nn.Module):
                      load_format: str = "auto",
                      revision: Optional[str] = None):
         if self.quant_config is None:
-            weight_suffixes = ["weight"]
+            col_weight_suffixes = ["weight"]
+            row_weight_suffixes = ["weight"]
         else:
-            weight_suffixes = self.quant_config.get_tp_tensor_names()
+            col_weight_suffixes = (
+                self.quant_config.get_col_parallel_tensor_names())
+            row_weight_suffixes = (
+                self.quant_config.get_row_parallel_tensor_names())
 
         column_parallel_weights: List[str] = []
         for layer in self._column_parallel_layers:
-            for suffix in weight_suffixes:
+            for suffix in col_weight_suffixes:
                 column_parallel_weights.append(f"{layer}.{suffix}")
         row_parallel_weights: List[str] = []
         for layer in self._row_parallel_layers:
-            for suffix in weight_suffixes:
+            for suffix in row_weight_suffixes:
                 row_parallel_weights.append(f"{layer}.{suffix}")
 
         tp_size = get_tensor_model_parallel_world_size()
@@ -351,10 +355,10 @@ class LlamaForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            is_packed = False
+            packed_dim = None
             is_transposed = False
             if self.quant_config is not None:
-                is_packed = self.quant_config.is_packed(name)
+                packed_dim = self.quant_config.get_packed_dim(name)
                 is_transposed = self.quant_config.is_transposed(name)
             if is_transposed:
                 loaded_weight = convert_pyslice_to_tensor(loaded_weight)
@@ -368,9 +372,11 @@ class LlamaForCausalLM(nn.Module):
                 if is_transposed:
                     param = param.T
 
-                if is_packed:
-                    shard_size //= self.quant_config.pack_factor
-                    offset //= self.quant_config.pack_factor
+                if packed_dim is not None:
+                    shard_dim = 0 if not is_transposed else 1
+                    if packed_dim == shard_dim:
+                        shard_size //= self.quant_config.pack_factor
+                        offset //= self.quant_config.pack_factor
 
                 if weight_name in ["k_proj", "v_proj"]:
                     shard_id = tp_rank // num_kv_heads_replicas
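The pack_factor division now applies only when the packed dimension coincides with the dimension being sharded. A small sketch of the branch with example numbers (SqueezeLLM packs qweight along dim 0, and quantized weights load transposed, so this shard keeps its size; AWQ, with packed_dim = 1, would take the division):

pack_factor = 8
shard_size, offset = 4096, 8192
packed_dim, is_transposed = 0, True   # SqueezeLLM qweight
shard_dim = 0 if not is_transposed else 1
if packed_dim is not None and packed_dim == shard_dim:
    shard_size //= pack_factor        # not taken in this example
    offset //= pack_factor
assert shard_size == 4096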
@@ -298,17 +298,21 @@ class MistralForCausalLM(nn.Module):
                      load_format: str = "auto",
                      revision: Optional[str] = None):
         if self.quant_config is None:
-            weight_suffixes = ["weight"]
+            col_weight_suffixes = ["weight"]
+            row_weight_suffixes = ["weight"]
         else:
-            weight_suffixes = self.quant_config.get_tp_tensor_names()
+            col_weight_suffixes = (
+                self.quant_config.get_col_parallel_tensor_names())
+            row_weight_suffixes = (
+                self.quant_config.get_row_parallel_tensor_names())
 
         column_parallel_weights: List[str] = []
         for layer in self._column_parallel_layers:
-            for suffix in weight_suffixes:
+            for suffix in col_weight_suffixes:
                 column_parallel_weights.append(f"{layer}.{suffix}")
         row_parallel_weights: List[str] = []
         for layer in self._row_parallel_layers:
-            for suffix in weight_suffixes:
+            for suffix in row_weight_suffixes:
                 row_parallel_weights.append(f"{layer}.{suffix}")
 
         tp_size = get_tensor_model_parallel_world_size()
@@ -331,10 +335,10 @@ class MistralForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            is_packed = False
+            packed_dim = None
             is_transposed = False
             if self.quant_config is not None:
-                is_packed = self.quant_config.is_packed(name)
+                packed_dim = self.quant_config.get_packed_dim(name)
                 is_transposed = self.quant_config.is_transposed(name)
             if is_transposed:
                 loaded_weight = convert_pyslice_to_tensor(loaded_weight)
@@ -348,9 +352,11 @@ class MistralForCausalLM(nn.Module):
                 if is_transposed:
                     param = param.T
 
-                if is_packed:
-                    shard_size //= self.quant_config.pack_factor
-                    offset //= self.quant_config.pack_factor
+                if packed_dim is not None:
+                    shard_dim = 0 if not is_transposed else 1
+                    if packed_dim == shard_dim:
+                        shard_size //= self.quant_config.pack_factor
+                        offset //= self.quant_config.pack_factor
 
                 loaded_weight = loaded_weight[
                     shard_size * tensor_model_parallel_rank:shard_size *
@@ -2,9 +2,11 @@ from typing import Type
 
 from vllm.model_executor.quantization_utils.awq import AWQConfig
 from vllm.model_executor.quantization_utils.base import QuantizationConfig
+from vllm.model_executor.quantization_utils.squeezellm import SqueezeLLMConfig
 
 _QUANTIZATION_REGISTRY = {
     "awq": AWQConfig,
+    "squeezellm": SqueezeLLMConfig,
 }
 
@@ -60,13 +60,17 @@ class AWQConfig(QuantizationConfig):
         return cls(weight_bits, group_size, zero_point)
 
     @classmethod
-    def get_packed_tensor_names(cls) -> List[str]:
-        return ["qweight", "qzeros"]
+    def get_packed_tensors(cls) -> Dict[str, int]:
+        return {"qweight": 1, "qzeros": 1}
 
     @classmethod
     def get_transposed_tensor_names(cls) -> List[str]:
         return ["qweight", "qzeros", "scales"]
 
     @classmethod
-    def get_tp_tensor_names(cls) -> List[str]:
+    def get_col_parallel_tensor_names(cls) -> List[str]:
+        return ["qweight", "qzeros", "scales"]
+
+    @classmethod
+    def get_row_parallel_tensor_names(cls) -> List[str]:
         return ["qweight", "qzeros", "scales"]
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import torch
 
@@ -45,19 +45,25 @@ class QuantizationConfig:
             "quantization config.")
 
     @classmethod
-    def get_packed_tensor_names(cls) -> List[str]:
+    def get_packed_tensors(cls) -> Dict[str, int]:
+        """Returns a dictionary of packed tensor names and their pack dims."""
         raise NotImplementedError
 
     @classmethod
-    def is_packed(cls, tensor_name: str) -> bool:
-        """Returns True if a tensor is packed.
+    def get_packed_dim(cls, tensor_name: str) -> Optional[int]:
+        """Returns the pack dim of a tensor if it is packed.
 
         A tensor is considered packed if each element in the tensor is a
         packed representation of multiple elements in the original tensor.
         For example, an INT32 element in the tensor may represent 8 INT4
         elements in the original tensor.
+        If the tensor is not packed, returns None.
         """
-        return any(tag in tensor_name for tag in cls.get_packed_tensor_names())
+        packed_tensors = cls.get_packed_tensors()
+        for packed_tensor_name, pack_dim in packed_tensors.items():
+            if packed_tensor_name in tensor_name:
+                return pack_dim
+        return None
 
     @classmethod
     def get_transposed_tensor_names(cls) -> List[str]:
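Expected behavior of the replacement API (a sketch; the parameter names are illustrative): AWQ reports pack dim 1 for its packed tensors and None for everything else.

from vllm.model_executor.quantization_utils.awq import AWQConfig

assert AWQConfig.get_packed_dim("layers.0.qkv_proj.qweight") == 1
assert AWQConfig.get_packed_dim("layers.0.qkv_proj.scales") is None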
@@ -71,5 +77,9 @@ class QuantizationConfig:
                    for tag in cls.get_transposed_tensor_names())
 
     @classmethod
-    def get_tp_tensor_names(cls) -> List[str]:
+    def get_col_parallel_tensor_names(cls) -> List[str]:
+        raise NotImplementedError
+
+    @classmethod
+    def get_row_parallel_tensor_names(cls) -> List[str]:
         raise NotImplementedError
vllm/model_executor/quantization_utils/squeezellm.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+from typing import Any, Dict, List
+
+import torch
+
+from vllm.model_executor.quantization_utils.base import QuantizationConfig
+
+
+class SqueezeLLMConfig(QuantizationConfig):
+    """Config class for SqueezeLLM.
+
+    Reference: https://arxiv.org/pdf/2306.07629
+    """
+
+    def __init__(
+        self,
+        weight_bits: int,
+    ) -> None:
+        self.weight_bits = weight_bits
+
+        if self.weight_bits != 4:
+            raise ValueError(
+                "Currently, only 4-bit weight quantization is supported for "
+                f"SqueezeLLM, but got {self.weight_bits} bits.")
+
+        self.pack_factor = 32 // self.weight_bits
+
+    def __repr__(self) -> str:
+        return f"SqueezeLLMConfig(weight_bits={self.weight_bits})"
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "squeezellm"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 70
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["quant_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
+        weight_bits = cls.get_from_keys(config, ["wbits"])
+        return cls(weight_bits)
+
+    @classmethod
+    def get_packed_tensors(cls) -> Dict[str, int]:
+        return {"qweight": 0}
+
+    @classmethod
+    def get_transposed_tensor_names(cls) -> List[str]:
+        return ["qweight"]
+
+    @classmethod
+    def get_col_parallel_tensor_names(cls) -> List[str]:
+        return ["qweight", "lookup_table"]
+
+    @classmethod
+    def get_row_parallel_tensor_names(cls) -> List[str]:
+        return ["qweight"]
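A hedged end-to-end sketch of the config plumbing: from_config reads only the "wbits" key from quant_config.json, and 4 is the only accepted value.

from vllm.model_executor.quantization_utils.squeezellm import SqueezeLLMConfig

quant_config = SqueezeLLMConfig.from_config({"wbits": 4})
assert quant_config.pack_factor == 8   # 32 // 4
print(quant_config)                    # SqueezeLLMConfig(weight_bits=4)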