Support SqueezeLLM (#1326)
Co-authored-by: squeeze-ai-lab <squeezeailab.bair@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent bf31d3606a
commit 1f24755bf8
benchmarks/benchmark_latency.py
@@ -70,7 +70,7 @@ if __name__ == '__main__':
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', None],
+                        choices=['awq', 'squeezellm', None],
                         default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
benchmarks/benchmark_throughput.py
@@ -201,7 +201,7 @@ if __name__ == "__main__":
     parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', None],
+                        choices=['awq', 'squeezellm', None],
                         default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n",
csrc/quantization.cpp
@@ -7,9 +7,13 @@ torch::Tensor awq_gemm(
   torch::Tensor _zeros,
   int split_k_iters);

+void squeezellm_gemm(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor lookup_table);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def(
-    "awq_gemm",
-    &awq_gemm,
-    "Quantized GEMM for AWQ");
+  m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
+  m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
 }
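For context, the bound op becomes callable from Python through the `vllm.quantization_ops` extension module. A minimal sketch of a direct call, with shapes and dtypes inferred from the wrapper in the new kernel file below; the sizes here are illustrative assumptions, not part of the commit:

import torch
from vllm import quantization_ops

batch, in_features, out_features = 1, 4096, 4096
# vec: fp16 activations; mat: int32 words, each packing eight 4-bit codes
# along the input dimension; lookup_table: 16 fp16 centroids per output column.
vec = torch.randn(batch, in_features, dtype=torch.float16, device="cuda")
mat = torch.randint(0, 2**31 - 1, (in_features // 8, out_features),
                    dtype=torch.int32, device="cuda")
lookup_table = torch.randn(out_features, 16, dtype=torch.float16, device="cuda")
# The kernel accumulates with atomicAdd, so the output must start at zero.
mul = torch.zeros(batch, out_features, dtype=torch.float16, device="cuda")
quantization_ops.squeezellm_gemm(vec, mat, mul, lookup_table)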
csrc/quantization/squeezellm/quant_cuda_kernel.cu (new file, 148 lines)
@@ -0,0 +1,148 @@
+#include <torch/all.h>
+#include <torch/python.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+// half-tensor
+#include <c10/cuda/CUDAStream.h>
+#include <ATen/cuda/CUDATensorMethods.cuh>
+
+#define BLOCKWIDTH 128
+#define BLOCKHEIGHT4 16
+
+namespace vllm {
+namespace squeezellm {
+
+__device__ inline unsigned int as_unsigned(int i) {
+  return *reinterpret_cast<unsigned int*>(&i);
+}
+
+// 4-bit matvec kernel (LUT-based)
+__global__ void NUQ4MatMulKernel(
+    const half2* __restrict__ vec,
+    const int* __restrict__ mat,
+    half2* __restrict__ mul,
+    const __half* __restrict__ lookup_table,
+    int height,
+    int width,
+    int batch,
+    int vec_height
+) {
+
+  const int blockwidth2 = BLOCKWIDTH / 2;
+
+  int row = BLOCKHEIGHT4 * blockIdx.x;
+  int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+
+  __shared__ half2 blockvec[blockwidth2];
+
+  __shared__ __half deq2[16][BLOCKWIDTH];
+  int off = threadIdx.x;
+  int column_offset = col * 16;
+  for (int val = 0; val < 16; val += 1) {
+    int lut_index = column_offset + val;
+    deq2[val][off] = lookup_table[lut_index];
+  }
+
+  __half res;
+  half2 res2;
+  half2 tmp2;
+
+  int i;
+  int k;
+
+  unsigned int tmp1;
+  unsigned int lut_index1, lut_index2;
+
+  for (int b = 0; b < batch; ++b){
+    i = width * row + col;
+    res = __int2half_rd(0);
+    k = 0;
+
+    __syncthreads();
+    if (threadIdx.x < blockwidth2)
+      blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
+    __syncthreads();
+
+    while (k < blockwidth2) {
+      tmp1 = as_unsigned(mat[i]);
+
+      res2 = {};
+      tmp2 = {};
+
+      lut_index1 = tmp1 & 0xF;
+      lut_index2 = (tmp1 >> 4) & 0xF;
+      tmp2.x = deq2[lut_index1][off];
+      tmp2.y = deq2[lut_index2][off];
+      res2 = __hfma2(tmp2, blockvec[k + 0], res2);
+
+      lut_index1 = (tmp1 >> 8) & 0xF;
+      lut_index2 = (tmp1 >> 12) & 0xF;
+      tmp2.x = deq2[lut_index1][off];
+      tmp2.y = deq2[lut_index2][off];
+      res2 = __hfma2(tmp2, blockvec[k + 1], res2);
+
+      lut_index1 = (tmp1 >> 16) & 0xF;
+      lut_index2 = (tmp1 >> 20) & 0xF;
+      tmp2.x = deq2[lut_index1][off];
+      tmp2.y = deq2[lut_index2][off];
+      res2 = __hfma2(tmp2, blockvec[k + 2], res2);
+
+      lut_index1 = (tmp1 >> 24) & 0xF;
+      lut_index2 = (tmp1 >> 28) & 0xF;
+      tmp2.x = deq2[lut_index1][off];
+      tmp2.y = deq2[lut_index2][off];
+      res2 = __hfma2(tmp2, blockvec[k + 3], res2);
+
+      res = __hadd(__hadd(res2.x, res2.y), res);
+
+      i += width;
+      k += 4;
+    }
+
+    // col%2 -> only set one of the two values
+    half2 res3 = {};
+    if (col % 2 == 0) {
+      res3.x = res;
+    } else {
+      res3.y = res;
+    }
+
+    atomicAdd(&mul[b * width / 2 + col / 2], res3);
+  }
+}
+
+} // namespace squeezellm
+} // namespace vllm
+
+// 4-bit matvec kernel (LUT-based)
+void squeezellm_gemm(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor lookup_table
+) {
+  int height = mat.size(0);
+  int width = mat.size(1);
+
+  int batch = vec.size(0);
+  int vec_height = vec.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
+    (half2*) vec.data<at::Half>(),
+    mat.data_ptr<int>(),
+    (half2*) mul.data<at::Half>(),
+    (__half*) lookup_table.data<at::Half>(),
+    height, width, batch, vec_height
+  );
+}
+
+#undef BLOCKWIDTH
+#undef BLOCKHEIGHT4
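To make the kernel's unpacking explicit, here is an illustrative pure-PyTorch rendition of what NUQ4MatMulKernel computes: each int32 in `mat` packs eight 4-bit codes along the input dimension (low nibble first), and each code indexes the 16-entry lookup table of its output column. This is a sketch for understanding, not part of the commit; `squeezellm_matmul_reference` is a hypothetical helper name:

import torch

def squeezellm_matmul_reference(vec: torch.Tensor,          # (batch, in_features), fp16
                                mat: torch.Tensor,          # (in_features // 8, out_features), int32
                                lookup_table: torch.Tensor  # (out_features, 16), fp16
                                ) -> torch.Tensor:
    # Unpack eight 4-bit codes from every int32 word, mirroring the
    # shift-and-mask pattern in the kernel (masking with 0xF makes the
    # arithmetic right shift on negative int32 values harmless).
    shifts = torch.arange(0, 32, 4, device=mat.device)
    codes = (mat.unsqueeze(-1) >> shifts) & 0xF                 # (in/8, out, 8)
    codes = codes.permute(0, 2, 1).reshape(-1, mat.shape[1])    # (in_features, out)
    # Dequantize: W[i, j] = lookup_table[j, code(i, j)].
    weight = torch.gather(lookup_table.T, 0, codes.long())      # (in_features, out)
    return vec.float() @ weight.float()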
setup.py
@@ -200,6 +200,7 @@ quantization_extension = CUDAExtension(
     sources=[
         "csrc/quantization.cpp",
         "csrc/quantization/awq/gemm_kernels.cu",
+        "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
     ],
     extra_compile_args={
         "cxx": CXX_FLAGS,
vllm/config.py
@@ -103,7 +103,7 @@ class ModelConfig:
         self.tokenizer_mode = tokenizer_mode

     def _verify_quantization(self) -> None:
-        supported_quantization = ["awq"]
+        supported_quantization = ["awq", "squeezellm"]
         if self.quantization is None:
             return
         quantization = self.quantization.lower()
vllm/engine/arg_utils.py
@@ -168,7 +168,7 @@ class EngineArgs:
         parser.add_argument('--quantization',
                             '-q',
                             type=str,
-                            choices=['awq', None],
+                            choices=['awq', 'squeezellm', None],
                             default=None,
                             help='Method used to quantize the weights')
         return parser
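Once wired through EngineArgs, a SqueezeLLM checkpoint is selected the same way an AWQ one is. A hedged usage sketch; the model path is a placeholder, and `quantization=` is assumed to be the same keyword the existing AWQ flow accepts on the `LLM` entrypoint:

from vllm import LLM, SamplingParams

# Assumes a SqueezeLLM-quantized checkpoint that ships a quant_config.json
# (see SqueezeLLMConfig.get_config_filenames later in this commit).
llm = LLM(model="path/to/squeezellm-llama-checkpoint",
          quantization="squeezellm")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=32))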
vllm/model_executor/layers/quantized_linear/__init__.py
@@ -1,10 +1,14 @@
 from vllm.model_executor.layers.quantized_linear.awq import (
     AWQColumnParallelLinear, AWQRowParallelLinear)
+from vllm.model_executor.layers.quantized_linear.squeezellm import (
+    SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear)
 from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
                                                        RowParallelLinear)

 _QUANTIZED_LINEAR_REGISTRY = {
     "awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
+    "squeezellm":
+    (SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear),
 }

vllm/model_executor/layers/quantized_linear/awq.py
@@ -11,9 +11,11 @@ from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
 class AWQColumnParallelLinear(ColumnParallelLinear):

     def create_weights(self, dtype: torch.dtype) -> None:
-        assert self.input_size % self.quant_config.weight_bits == 0
-        assert (self.output_size_per_partition %
-                self.quant_config.pack_factor == 0)
+        assert self.input_size % self.quant_config.group_size == 0
+        if self.output_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                "The tensor parallel size is not aligned with the quantized "
+                "weight shape. Please use a different tensor parallel size.")
         self.qweight = Parameter(
             torch.empty(
                 self.input_size,
@@ -62,9 +64,11 @@ class AWQColumnParallelLinear(ColumnParallelLinear):
 class AWQRowParallelLinear(RowParallelLinear):

     def create_weights(self, dtype: torch.dtype) -> None:
-        assert (self.input_size_per_partition %
-                self.quant_config.weight_bits == 0)
         assert self.output_size % self.quant_config.pack_factor == 0
+        if self.input_size_per_partition % self.quant_config.group_size != 0:
+            raise ValueError(
+                "The tensor parallel size is not aligned with the quantized "
+                "weight shape. Please use a different tensor parallel size.")
         self.qweight = Parameter(
             torch.empty(
                 self.input_size_per_partition,
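The asserts become a user-facing error because the divisibility requirement depends on the chosen tensor parallel size, not on the checkpoint alone. A quick illustration with AWQ's common defaults (4-bit weights, group size 128) and LLaMA-7B's 11008-wide FFN; the numbers are illustrative only:

# 4-bit AWQ: pack_factor = 32 // 4 = 8, typical group_size = 128.
group_size = 128
input_size = 11008  # LLaMA-7B intermediate size

for tp_size in (1, 2, 4, 8):
    shard = input_size // tp_size
    print(f"tp={tp_size}: input shard {shard}, aligned={shard % group_size == 0}")
# tp=4 and tp=8 leave shards (2752 and 1376) that are not multiples of 128,
# so create_weights now raises the alignment ValueError instead of asserting.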
vllm/model_executor/layers/quantized_linear/squeezellm.py (new file, 84 lines)
@@ -0,0 +1,84 @@
+from typing import Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm import quantization_ops
+from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
+                                                       RowParallelLinear)
+
+
+class SqueezeLLMColumnParallelLinear(ColumnParallelLinear):
+
+    def create_weights(self, dtype: torch.dtype) -> None:
+        assert self.input_size % self.quant_config.pack_factor == 0
+        self.qweight = Parameter(
+            torch.empty(
+                self.input_size // self.quant_config.pack_factor,
+                self.output_size_per_partition,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        self.lookup_table = Parameter(
+            torch.empty(
+                self.output_size_per_partition,
+                self.quant_config.weight_bits**2,
+                device="cuda",
+                dtype=dtype,
+            ),
+            requires_grad=False,
+        )
+
+    def apply_weights(
+        self,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
+        reshaped_x = x.reshape(-1, x.shape[-1])
+        # NOTE: The output tensor should be zero-initialized.
+        out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
+        quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
+                                         self.lookup_table)
+
+        if bias is not None:
+            out = out + bias
+        return out.reshape(out_shape)
+
+
+class SqueezeLLMRowParallelLinear(RowParallelLinear):
+
+    def create_weights(self, dtype: torch.dtype) -> None:
+        if self.input_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                "The tensor parallel size is not aligned with the quantized "
+                "weight shape. Please use a different tensor parallel size.")
+        self.qweight = Parameter(
+            torch.empty(
+                self.input_size_per_partition // self.quant_config.pack_factor,
+                self.output_size,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        self.lookup_table = Parameter(
+            torch.empty(
+                self.output_size,
+                self.quant_config.weight_bits**2,
+                device="cuda",
+                dtype=dtype,
+            ),
+            requires_grad=False,
+        )
+
+    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
+        reshaped_x = x.reshape(-1, x.shape[-1])
+        out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
+        # NOTE: The output tensor should be zero-initialized.
+        out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
+        quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
+                                         self.lookup_table)
+        return out.reshape(out_shape)
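The NOTE about zero-initialization matters because the kernel accumulates into the output with atomicAdd rather than overwriting it. A hedged way to sanity-check the layers' core op in isolation, reusing the hypothetical `squeezellm_matmul_reference` sketch from earlier (tolerances are loose because the kernel accumulates in fp16):

import torch
from vllm import quantization_ops

torch.manual_seed(0)
batch, in_features, out_features = 4, 1024, 1024
x = torch.randn(batch, in_features, dtype=torch.float16, device="cuda")
qweight = torch.randint(0, 2**31 - 1, (in_features // 8, out_features),
                        dtype=torch.int32, device="cuda")
lut = torch.randn(out_features, 16, dtype=torch.float16, device="cuda")

# As the NOTE in apply_weights says: zero-init, since the kernel atomicAdds.
out = torch.zeros(batch, out_features, dtype=torch.float16, device="cuda")
quantization_ops.squeezellm_gemm(x, qweight, out, lut)

ref = squeezellm_matmul_reference(x, qweight, lut)
print(torch.allclose(out.float(), ref, atol=0.5, rtol=0.05))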
vllm/model_executor/models/llama.py
@@ -314,17 +314,21 @@ class LlamaForCausalLM(nn.Module):
                      load_format: str = "auto",
                      revision: Optional[str] = None):
         if self.quant_config is None:
-            weight_suffixes = ["weight"]
+            col_weight_suffixes = ["weight"]
+            row_weight_suffixes = ["weight"]
         else:
-            weight_suffixes = self.quant_config.get_tp_tensor_names()
+            col_weight_suffixes = (
+                self.quant_config.get_col_parallel_tensor_names())
+            row_weight_suffixes = (
+                self.quant_config.get_row_parallel_tensor_names())

         column_parallel_weights: List[str] = []
         for layer in self._column_parallel_layers:
-            for suffix in weight_suffixes:
+            for suffix in col_weight_suffixes:
                 column_parallel_weights.append(f"{layer}.{suffix}")
         row_parallel_weights: List[str] = []
         for layer in self._row_parallel_layers:
-            for suffix in weight_suffixes:
+            for suffix in row_weight_suffixes:
                 row_parallel_weights.append(f"{layer}.{suffix}")

         tp_size = get_tensor_model_parallel_world_size()
@@ -351,10 +355,10 @@ class LlamaForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue

-            is_packed = False
+            packed_dim = None
             is_transposed = False
             if self.quant_config is not None:
-                is_packed = self.quant_config.is_packed(name)
+                packed_dim = self.quant_config.get_packed_dim(name)
                 is_transposed = self.quant_config.is_transposed(name)
             if is_transposed:
                 loaded_weight = convert_pyslice_to_tensor(loaded_weight)
@@ -368,7 +372,9 @@ class LlamaForCausalLM(nn.Module):
                 if is_transposed:
                     param = param.T

-                if is_packed:
-                    shard_size //= self.quant_config.pack_factor
-                    offset //= self.quant_config.pack_factor
+                if packed_dim is not None:
+                    shard_dim = 0 if not is_transposed else 1
+                    if packed_dim == shard_dim:
+                        shard_size //= self.quant_config.pack_factor
+                        offset //= self.quant_config.pack_factor

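Why `packed_dim` instead of a boolean: AWQ packs its (transposed) qweight/qzeros along dim 1, while SqueezeLLM packs qweight along dim 0, and the shard arithmetic should only shrink by `pack_factor` when the dimension being sharded is the packed one, which a bare `is_packed` flag cannot express. A toy rendition of the new branch, with made-up sizes and `pack_factor = 8` for 4-bit weights:

pack_factor = 8  # 32 // 4 for 4-bit weights

def adjust_shard(shard_size, offset, packed_dim, is_transposed):
    # Mirrors the new branch in load_weights.
    shard_dim = 0 if not is_transposed else 1
    if packed_dim is not None and packed_dim == shard_dim:
        shard_size //= pack_factor
        offset //= pack_factor
    return shard_size, offset

# AWQ qweight: transposed, packed along dim 1 -> sharded dim is packed, divide.
print(adjust_shard(4096, 8192, packed_dim=1, is_transposed=True))     # (512, 1024)
# SqueezeLLM qweight: transposed, packed along dim 0 (the input dim)
# -> the sharded dim is not packed, leave untouched.
print(adjust_shard(4096, 8192, packed_dim=0, is_transposed=True))     # (4096, 8192)
# Unpacked tensors (packed_dim is None) are never adjusted.
print(adjust_shard(4096, 8192, packed_dim=None, is_transposed=False)) # (4096, 8192)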
vllm/model_executor/models/mistral.py
@@ -298,17 +298,21 @@ class MistralForCausalLM(nn.Module):
                      load_format: str = "auto",
                      revision: Optional[str] = None):
         if self.quant_config is None:
-            weight_suffixes = ["weight"]
+            col_weight_suffixes = ["weight"]
+            row_weight_suffixes = ["weight"]
         else:
-            weight_suffixes = self.quant_config.get_tp_tensor_names()
+            col_weight_suffixes = (
+                self.quant_config.get_col_parallel_tensor_names())
+            row_weight_suffixes = (
+                self.quant_config.get_row_parallel_tensor_names())

         column_parallel_weights: List[str] = []
         for layer in self._column_parallel_layers:
-            for suffix in weight_suffixes:
+            for suffix in col_weight_suffixes:
                 column_parallel_weights.append(f"{layer}.{suffix}")
         row_parallel_weights: List[str] = []
         for layer in self._row_parallel_layers:
-            for suffix in weight_suffixes:
+            for suffix in row_weight_suffixes:
                 row_parallel_weights.append(f"{layer}.{suffix}")

         tp_size = get_tensor_model_parallel_world_size()
@@ -331,10 +335,10 @@ class MistralForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue

-            is_packed = False
+            packed_dim = None
             is_transposed = False
             if self.quant_config is not None:
-                is_packed = self.quant_config.is_packed(name)
+                packed_dim = self.quant_config.get_packed_dim(name)
                 is_transposed = self.quant_config.is_transposed(name)
             if is_transposed:
                 loaded_weight = convert_pyslice_to_tensor(loaded_weight)
@@ -348,7 +352,9 @@ class MistralForCausalLM(nn.Module):
                 if is_transposed:
                     param = param.T

-                if is_packed:
-                    shard_size //= self.quant_config.pack_factor
-                    offset //= self.quant_config.pack_factor
+                if packed_dim is not None:
+                    shard_dim = 0 if not is_transposed else 1
+                    if packed_dim == shard_dim:
+                        shard_size //= self.quant_config.pack_factor
+                        offset //= self.quant_config.pack_factor

|
|||||||
|
|
||||||
from vllm.model_executor.quantization_utils.awq import AWQConfig
|
from vllm.model_executor.quantization_utils.awq import AWQConfig
|
||||||
from vllm.model_executor.quantization_utils.base import QuantizationConfig
|
from vllm.model_executor.quantization_utils.base import QuantizationConfig
|
||||||
|
from vllm.model_executor.quantization_utils.squeezellm import SqueezeLLMConfig
|
||||||
|
|
||||||
_QUANTIZATION_REGISTRY = {
|
_QUANTIZATION_REGISTRY = {
|
||||||
"awq": AWQConfig,
|
"awq": AWQConfig,
|
||||||
|
"squeezellm": SqueezeLLMConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
vllm/model_executor/quantization_utils/awq.py
@@ -60,13 +60,17 @@ class AWQConfig(QuantizationConfig):
         return cls(weight_bits, group_size, zero_point)

     @classmethod
-    def get_packed_tensor_names(cls) -> List[str]:
-        return ["qweight", "qzeros"]
+    def get_packed_tensors(cls) -> Dict[str, int]:
+        return {"qweight": 1, "qzeros": 1}

     @classmethod
     def get_transposed_tensor_names(cls) -> List[str]:
         return ["qweight", "qzeros", "scales"]

     @classmethod
-    def get_tp_tensor_names(cls) -> List[str]:
+    def get_col_parallel_tensor_names(cls) -> List[str]:
+        return ["qweight", "qzeros", "scales"]
+
+    @classmethod
+    def get_row_parallel_tensor_names(cls) -> List[str]:
         return ["qweight", "qzeros", "scales"]
vllm/model_executor/quantization_utils/base.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

 import torch

@@ -45,19 +45,25 @@ class QuantizationConfig:
                              "quantization config.")

     @classmethod
-    def get_packed_tensor_names(cls) -> List[str]:
+    def get_packed_tensors(cls) -> Dict[str, int]:
+        """Returns a dictionary of packed tensor names and their pack dims."""
         raise NotImplementedError

     @classmethod
-    def is_packed(cls, tensor_name: str) -> bool:
-        """Returns True if a tensor is packed.
+    def get_packed_dim(cls, tensor_name: str) -> Optional[int]:
+        """Returns the pack dim of a tensor if it is packed.

         A tensor is considered packed if each element in the tensor is a
         packed representation of multiple elements in the original tensor.
         For example, an INT32 element in the tensor may represent 8 INT4
         elements in the original tensor.
+        If the tensor is not packed, returns None.
         """
-        return any(tag in tensor_name for tag in cls.get_packed_tensor_names())
+        packed_tensors = cls.get_packed_tensors()
+        for packed_tensor_name, pack_dim in packed_tensors.items():
+            if packed_tensor_name in tensor_name:
+                return pack_dim
+        return None

     @classmethod
     def get_transposed_tensor_names(cls) -> List[str]:
@@ -71,5 +77,9 @@ class QuantizationConfig:
                    for tag in cls.get_transposed_tensor_names())

     @classmethod
-    def get_tp_tensor_names(cls) -> List[str]:
+    def get_col_parallel_tensor_names(cls) -> List[str]:
+        raise NotImplementedError
+
+    @classmethod
+    def get_row_parallel_tensor_names(cls) -> List[str]:
         raise NotImplementedError
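Putting the base-class change together with the two configs in this commit (SqueezeLLMConfig is defined in the new file below), the lookup resolves like this; the tensor names are illustrative LLaMA parameter names, and all methods involved are classmethods, so no instantiation is needed:

from vllm.model_executor.quantization_utils.awq import AWQConfig
from vllm.model_executor.quantization_utils.squeezellm import SqueezeLLMConfig

# get_packed_dim() substring-matches the keys of get_packed_tensors():
print(AWQConfig.get_packed_dim("layers.0.self_attn.qkv_proj.qweight"))       # 1
print(SqueezeLLMConfig.get_packed_dim("layers.0.mlp.gate_proj.qweight"))     # 0
print(AWQConfig.get_packed_dim("layers.0.self_attn.qkv_proj.scales"))        # None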
vllm/model_executor/quantization_utils/squeezellm.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+from typing import Any, Dict, List
+
+import torch
+
+from vllm.model_executor.quantization_utils.base import QuantizationConfig
+
+
+class SqueezeLLMConfig(QuantizationConfig):
+    """Config class for SqueezeLLM.
+
+    Reference: https://arxiv.org/pdf/2306.07629
+    """
+
+    def __init__(
+        self,
+        weight_bits: int,
+    ) -> None:
+        self.weight_bits = weight_bits
+
+        if self.weight_bits != 4:
+            raise ValueError(
+                "Currently, only 4-bit weight quantization is supported for "
+                f"SqueezeLLM, but got {self.weight_bits} bits.")
+
+        self.pack_factor = 32 // self.weight_bits
+
+    def __repr__(self) -> str:
+        return f"SqueezeLLMConfig(weight_bits={self.weight_bits})"
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "squeezellm"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 70
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["quant_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
+        weight_bits = cls.get_from_keys(config, ["wbits"])
+        return cls(weight_bits)
+
+    @classmethod
+    def get_packed_tensors(cls) -> Dict[str, int]:
+        return {"qweight": 0}
+
+    @classmethod
+    def get_transposed_tensor_names(cls) -> List[str]:
+        return ["qweight"]
+
+    @classmethod
+    def get_col_parallel_tensor_names(cls) -> List[str]:
+        return ["qweight", "lookup_table"]
+
+    @classmethod
+    def get_row_parallel_tensor_names(cls) -> List[str]:
+        return ["qweight"]
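Finally, from_config consumes only the "wbits" key, so the minimal quant_config.json a SqueezeLLM checkpoint needs is tiny. A hedged sketch of loading it (other keys, if present, are ignored by this config class; get_from_keys is inherited from QuantizationConfig):

import json

from vllm.model_executor.quantization_utils.squeezellm import SqueezeLLMConfig

# Minimal quant_config.json contents as implied by from_config:
quant_config = SqueezeLLMConfig.from_config(json.loads('{"wbits": 4}'))
print(quant_config)              # SqueezeLLMConfig(weight_bits=4)
print(quant_config.pack_factor)  # 8 (= 32 // 4: eight 4-bit codes per int32)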