Support SqueezeLLM (#1326)

Co-authored-by: squeeze-ai-lab <squeezeailab.bair@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
chooper1 authored on 2023-10-22 03:14:59 -03:00, committed by GitHub
parent bf31d3606a
commit 1f24755bf8
16 changed files with 378 additions and 40 deletions

View File

@@ -70,7 +70,7 @@ if __name__ == '__main__':
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=['awq', None],
choices=['awq', 'squeezellm', None],
default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32)
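The same choice is accepted end to end; a minimal offline sketch (the model path is a placeholder for a SqueezeLLM-quantized checkpoint that ships a quant_config.json):

from vllm import LLM, SamplingParams

# Placeholder path: any directory holding SqueezeLLM weights plus a
# quant_config.json containing {"wbits": 4}.
llm = LLM(model="path/to/squeezellm-model", quantization="squeezellm")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
print(outputs[0].outputs[0].text)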

View File

@@ -201,7 +201,7 @@ if __name__ == "__main__":
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=['awq', None],
choices=['awq', 'squeezellm', None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",

View File

@@ -7,9 +7,13 @@ torch::Tensor awq_gemm(
torch::Tensor _zeros,
int split_k_iters);
void squeezellm_gemm(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"awq_gemm",
&awq_gemm,
"Quantized GEMM for AWQ");
m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
}
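From Python, the new binding is reached through the quantization_ops extension just like awq_gemm; a sketch of the call with illustrative shapes (the layer code later in this diff zero-initializes the output because the kernel accumulates with atomicAdd):

import torch
from vllm import quantization_ops

batch, in_features, out_features = 4, 4096, 4096
x = torch.rand(batch, in_features, dtype=torch.float16, device="cuda")
# Eight 4-bit lookup-table indices are packed into each int32 along the input dim.
qweight = torch.randint(-2**31, 2**31 - 1,
                        (in_features // 8, out_features),
                        dtype=torch.int32, device="cuda")
lookup_table = torch.rand(out_features, 16, dtype=torch.float16, device="cuda")
out = torch.zeros(batch, out_features, dtype=torch.float16, device="cuda")
quantization_ops.squeezellm_gemm(x, qweight, out, lookup_table)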

View File

@@ -0,0 +1,148 @@
#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>
#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16
namespace vllm {
namespace squeezellm {
__device__ inline unsigned int as_unsigned(int i) {
return *reinterpret_cast<unsigned int*>(&i);
}
// 4-bit matvec kernel (LUT-based)
__global__ void NUQ4MatMulKernel(
const half2* __restrict__ vec,
const int* __restrict__ mat,
half2* __restrict__ mul,
const __half* __restrict__ lookup_table,
int height,
int width,
int batch,
int vec_height
) {
const int blockwidth2 = BLOCKWIDTH / 2;
int row = BLOCKHEIGHT4 * blockIdx.x;
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
__shared__ half2 blockvec[blockwidth2];
__shared__ __half deq2[16][BLOCKWIDTH];
int off = threadIdx.x;
int column_offset = col * 16;
for (int val = 0; val < 16; val += 1) {
int lut_index = column_offset + val;
deq2[val][off] = lookup_table[lut_index];
}
__half res;
half2 res2;
half2 tmp2;
int i;
int k;
unsigned int tmp1;
unsigned int lut_index1, lut_index2;
for (int b = 0; b < batch; ++b){
i = width * row + col;
res = __int2half_rd(0);
k = 0;
__syncthreads();
if (threadIdx.x < blockwidth2)
blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
__syncthreads();
while (k < blockwidth2) {
tmp1 = as_unsigned(mat[i]);
res2 = {};
tmp2 = {};
lut_index1 = tmp1 & 0xF;
lut_index2 = (tmp1 >> 4) & 0xF;
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
res2 = __hfma2(tmp2, blockvec[k + 0], res2);
lut_index1 = (tmp1 >> 8) & 0xF;
lut_index2 = (tmp1 >> 12) & 0xF;
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
res2 = __hfma2(tmp2, blockvec[k + 1], res2);
lut_index1 = (tmp1 >> 16) & 0xF;
lut_index2 = (tmp1 >> 20) & 0xF;
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
res2 = __hfma2(tmp2, blockvec[k + 2], res2);
lut_index1 = (tmp1 >> 24) & 0xF;
lut_index2 = (tmp1 >> 28) & 0xF;
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
res2 = __hfma2(tmp2, blockvec[k + 3], res2);
res = __hadd(__hadd(res2.x, res2.y), res);
i += width;
k += 4;
}
// col%2 -> only set one of the two values
half2 res3 = {};
if (col % 2 == 0) {
res3.x = res;
} else {
res3.y = res;
}
atomicAdd(&mul[b * width / 2 + col / 2], res3);
}
}
} // namespace squeezellm
} // namespace vllm
// 4-bit matvec kernel (LUT-based)
void squeezellm_gemm(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table
) {
int height = mat.size(0);
int width = mat.size(1);
int batch = vec.size(0);
int vec_height = vec.size(1);
dim3 blocks(
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 threads(BLOCKWIDTH);
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
(half2*) vec.data<at::Half>(),
mat.data_ptr<int>(),
(half2*) mul.data<at::Half>(),
(__half*) lookup_table.data<at::Half>(),
height, width, batch, vec_height
);
}
#undef BLOCKWIDTH
#undef BLOCKHEIGHT4
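For reference, the following dense PyTorch sketch computes the same product the kernel produces (up to fp16 accumulation order); it assumes the layout used above: each int32 in mat packs eight 4-bit indices along the input dimension, least-significant nibble first, and lookup_table stores one 16-entry codebook per output column.

import torch

def squeezellm_matmul_reference(vec, mat, lookup_table):
    # vec:          [batch, in_features]             fp16 activations
    # mat:          [in_features // 8, out_features] int32 packed 4-bit indices
    # lookup_table: [out_features, 16]               fp16 per-column codebook
    in_features = mat.shape[0] * 8
    out_features = mat.shape[1]
    # Unpack eight 4-bit indices per int32, least-significant nibble first.
    shifts = torch.arange(0, 32, 4, device=mat.device, dtype=torch.int32)
    idx = (mat.unsqueeze(-1) >> shifts) & 0xF                  # [in//8, out, 8]
    idx = idx.permute(0, 2, 1).reshape(in_features, out_features).long()
    # Dequantize through the per-output-column lookup table.
    weight = torch.gather(lookup_table.t(), 0, idx)            # [in, out]
    return (vec.float() @ weight.float()).to(vec.dtype)        # fp32 accumulation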

View File

@@ -200,6 +200,7 @@ quantization_extension = CUDAExtension(
sources=[
"csrc/quantization.cpp",
"csrc/quantization/awq/gemm_kernels.cu",
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
],
extra_compile_args={
"cxx": CXX_FLAGS,

View File

@@ -103,7 +103,7 @@ class ModelConfig:
self.tokenizer_mode = tokenizer_mode
def _verify_quantization(self) -> None:
supported_quantization = ["awq"]
supported_quantization = ["awq", "squeezellm"]
if self.quantization is None:
return
quantization = self.quantization.lower()
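The rest of the method (outside this hunk) rejects anything not in the list; paraphrased, not verbatim:

# Sketch of the validation that follows the lines shown above.
if quantization not in supported_quantization:
    raise ValueError(f"Unknown quantization method: {quantization}. "
                     f"Must be one of {supported_quantization}.")
self.quantization = quantization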

View File

@@ -168,7 +168,7 @@ class EngineArgs:
parser.add_argument('--quantization',
'-q',
type=str,
choices=['awq', None],
choices=['awq', 'squeezellm', None],
default=None,
help='Method used to quantize the weights')
return parser

View File

@@ -1,10 +1,14 @@
from vllm.model_executor.layers.quantized_linear.awq import (
AWQColumnParallelLinear, AWQRowParallelLinear)
from vllm.model_executor.layers.quantized_linear.squeezellm import (
SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear)
_QUANTIZED_LINEAR_REGISTRY = {
"awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
"squeezellm":
(SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear),
}
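The selection helper in this module is unchanged and therefore absent from the diff; an illustrative sketch of how the registry is consulted (the helper name here is hypothetical):

def _get_quantized_linear_cls(quant_config, row_parallel):
    # quant_config.get_name() returns "awq" or "squeezellm".
    name = quant_config.get_name()
    if name not in _QUANTIZED_LINEAR_REGISTRY:
        raise ValueError(f"No quantized linear layer is registered for {name}.")
    column_cls, row_cls = _QUANTIZED_LINEAR_REGISTRY[name]
    return row_cls if row_parallel else column_cls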

View File

@@ -11,9 +11,11 @@ from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
class AWQColumnParallelLinear(ColumnParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert self.input_size % self.quant_config.weight_bits == 0
assert (self.output_size_per_partition %
self.quant_config.pack_factor == 0)
assert self.input_size % self.quant_config.group_size == 0
if self.output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The tensor parallel size is not aligned with the quantized "
"weight shape. Please use a different tensor parallel size.")
self.qweight = Parameter(
torch.empty(
self.input_size,
@@ -62,9 +64,11 @@ class AWQColumnParallelLinear(ColumnParallelLinear):
class AWQRowParallelLinear(RowParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert (self.input_size_per_partition %
self.quant_config.weight_bits == 0)
assert self.output_size % self.quant_config.pack_factor == 0
if self.input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The tensor parallel size is not aligned with the quantized "
"weight shape. Please use a different tensor parallel size.")
self.qweight = Parameter(
torch.empty(
self.input_size_per_partition,
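These ValueErrors typically fire when tensor parallelism slices a quantized dimension unevenly; a small worked check with illustrative numbers (13824 is a Llama-13B-style MLP width, 128 a common AWQ group size):

input_size, group_size = 13824, 128
for tp_size in (1, 2, 4, 8):
    per_partition = input_size // tp_size
    aligned = per_partition % group_size == 0
    print(f"tp={tp_size}: {per_partition} per rank ->",
          "ok" if aligned else "misaligned -> ValueError")
# tp=8 leaves 1728 per rank, which is not a multiple of 128, so the new message
# asks for a different tensor parallel size instead of a bare assertion failure.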

View File

@@ -0,0 +1,84 @@
from typing import Optional
import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear)
class SqueezeLLMColumnParallelLinear(ColumnParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert self.input_size % self.quant_config.pack_factor == 0
self.qweight = Parameter(
torch.empty(
self.input_size // self.quant_config.pack_factor,
self.output_size_per_partition,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.lookup_table = Parameter(
torch.empty(
self.output_size_per_partition,
self.quant_config.weight_bits**2,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
self.lookup_table)
if bias is not None:
out = out + bias
return out.reshape(out_shape)
class SqueezeLLMRowParallelLinear(RowParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
if self.input_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The tensor parallel size is not aligned with the quantized "
"weight shape. Please use a different tensor parallel size.")
self.qweight = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.pack_factor,
self.output_size,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.lookup_table = Parameter(
torch.empty(
self.output_size,
self.quant_config.weight_bits**2,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
self.lookup_table)
return out.reshape(out_shape)
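As a concrete shape check for the parameters created above (sizes are illustrative and assume no tensor parallelism):

weight_bits = 4
pack_factor = 32 // weight_bits                               # 8
in_features, out_features = 4096, 11008
qweight_shape = (in_features // pack_factor, out_features)    # (512, 11008), int32
lookup_table_shape = (out_features, 2 ** weight_bits)         # (11008, 16), fp16
# Note: the code sizes the table as weight_bits ** 2, which equals 2 ** weight_bits
# only because weight_bits is fixed to 4.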

View File

@@ -314,17 +314,21 @@ class LlamaForCausalLM(nn.Module):
load_format: str = "auto",
revision: Optional[str] = None):
if self.quant_config is None:
weight_suffixes = ["weight"]
col_weight_suffixes = ["weight"]
row_weight_suffixes = ["weight"]
else:
weight_suffixes = self.quant_config.get_tp_tensor_names()
col_weight_suffixes = (
self.quant_config.get_col_parallel_tensor_names())
row_weight_suffixes = (
self.quant_config.get_row_parallel_tensor_names())
column_parallel_weights: List[str] = []
for layer in self._column_parallel_layers:
for suffix in weight_suffixes:
for suffix in col_weight_suffixes:
column_parallel_weights.append(f"{layer}.{suffix}")
row_parallel_weights: List[str] = []
for layer in self._row_parallel_layers:
for suffix in weight_suffixes:
for suffix in row_weight_suffixes:
row_parallel_weights.append(f"{layer}.{suffix}")
tp_size = get_tensor_model_parallel_world_size()
@@ -351,10 +355,10 @@ class LlamaForCausalLM(nn.Module):
if "rotary_emb.inv_freq" in name:
continue
is_packed = False
packed_dim = None
is_transposed = False
if self.quant_config is not None:
is_packed = self.quant_config.is_packed(name)
packed_dim = self.quant_config.get_packed_dim(name)
is_transposed = self.quant_config.is_transposed(name)
if is_transposed:
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
@@ -368,9 +372,11 @@ class LlamaForCausalLM(nn.Module):
if is_transposed:
param = param.T
if is_packed:
shard_size //= self.quant_config.pack_factor
offset //= self.quant_config.pack_factor
if packed_dim is not None:
shard_dim = 0 if not is_transposed else 1
if packed_dim == shard_dim:
shard_size //= self.quant_config.pack_factor
offset //= self.quant_config.pack_factor
if weight_name in ["k_proj", "v_proj"]:
shard_id = tp_rank // num_kv_heads_replicas
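The packed_dim/shard_dim comparison matters because the two schemes pack along different axes: AWQ stores qweight transposed with the 4-bit packing on dim 1 (the output axis), while SqueezeLLM packs along dim 0 (the input axis). A sketch of the decision, mirroring the loop variables above:

def needs_pack_division(packed_dim, is_transposed):
    # Only divide shard_size/offset when the packed axis is the one being sharded.
    if packed_dim is None:
        return False
    shard_dim = 0 if not is_transposed else 1
    return packed_dim == shard_dim

print(needs_pack_division(packed_dim=1, is_transposed=True))   # AWQ qweight        -> True
print(needs_pack_division(packed_dim=0, is_transposed=True))   # SqueezeLLM qweight -> False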

View File

@@ -298,17 +298,21 @@ class MistralForCausalLM(nn.Module):
load_format: str = "auto",
revision: Optional[str] = None):
if self.quant_config is None:
weight_suffixes = ["weight"]
col_weight_suffixes = ["weight"]
row_weight_suffixes = ["weight"]
else:
weight_suffixes = self.quant_config.get_tp_tensor_names()
col_weight_suffixes = (
self.quant_config.get_col_parallel_tensor_names())
row_weight_suffixes = (
self.quant_config.get_row_parallel_tensor_names())
column_parallel_weights: List[str] = []
for layer in self._column_parallel_layers:
for suffix in weight_suffixes:
for suffix in col_weight_suffixes:
column_parallel_weights.append(f"{layer}.{suffix}")
row_parallel_weights: List[str] = []
for layer in self._row_parallel_layers:
for suffix in weight_suffixes:
for suffix in row_weight_suffixes:
row_parallel_weights.append(f"{layer}.{suffix}")
tp_size = get_tensor_model_parallel_world_size()
@@ -331,10 +335,10 @@ class MistralForCausalLM(nn.Module):
if "rotary_emb.inv_freq" in name:
continue
is_packed = False
packed_dim = None
is_transposed = False
if self.quant_config is not None:
is_packed = self.quant_config.is_packed(name)
packed_dim = self.quant_config.get_packed_dim(name)
is_transposed = self.quant_config.is_transposed(name)
if is_transposed:
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
@@ -348,9 +352,11 @@ class MistralForCausalLM(nn.Module):
if is_transposed:
param = param.T
if is_packed:
shard_size //= self.quant_config.pack_factor
offset //= self.quant_config.pack_factor
if packed_dim is not None:
shard_dim = 0 if not is_transposed else 1
if packed_dim == shard_dim:
shard_size //= self.quant_config.pack_factor
offset //= self.quant_config.pack_factor
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *

View File

@@ -2,9 +2,11 @@ from typing import Type
from vllm.model_executor.quantization_utils.awq import AWQConfig
from vllm.model_executor.quantization_utils.base import QuantizationConfig
from vllm.model_executor.quantization_utils.squeezellm import SqueezeLLMConfig
_QUANTIZATION_REGISTRY = {
"awq": AWQConfig,
"squeezellm": SqueezeLLMConfig,
}

View File

@@ -60,13 +60,17 @@ class AWQConfig(QuantizationConfig):
return cls(weight_bits, group_size, zero_point)
@classmethod
def get_packed_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros"]
def get_packed_tensors(cls) -> Dict[str, int]:
return {"qweight": 1, "qzeros": 1}
@classmethod
def get_transposed_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros", "scales"]
@classmethod
def get_tp_tensor_names(cls) -> List[str]:
def get_col_parallel_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros", "scales"]
@classmethod
def get_row_parallel_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros", "scales"]

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
import torch
@@ -45,19 +45,25 @@ class QuantizationConfig:
"quantization config.")
@classmethod
def get_packed_tensor_names(cls) -> List[str]:
def get_packed_tensors(cls) -> Dict[str, int]:
"""Returns a dictionary of packed tensor names and their pack dims."""
raise NotImplementedError
@classmethod
def is_packed(cls, tensor_name: str) -> bool:
"""Returns True if a tensor is packed.
def get_packed_dim(cls, tensor_name: str) -> Optional[int]:
"""Returns the pack dim of a tensor if it is packed.
A tensor is considered packed if each element in the tensor is a
packed representation of multiple elements in the original tensor.
For example, an INT32 element in the tensor may represent 8 INT4
elements in the original tensor.
If the tensor is not packed, returns None.
"""
return any(tag in tensor_name for tag in cls.get_packed_tensor_names())
packed_tensors = cls.get_packed_tensors()
for packed_tensor_name, pack_dim in packed_tensors.items():
if packed_tensor_name in tensor_name:
return pack_dim
return None
@classmethod
def get_transposed_tensor_names(cls) -> List[str]:
@@ -71,5 +77,9 @@ class QuantizationConfig:
for tag in cls.get_transposed_tensor_names())
@classmethod
def get_tp_tensor_names(cls) -> List[str]:
def get_col_parallel_tensor_names(cls) -> List[str]:
raise NotImplementedError
@classmethod
def get_row_parallel_tensor_names(cls) -> List[str]:
raise NotImplementedError
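With these hooks the weight loader can ask each config where, if anywhere, a tensor is packed; illustrative calls (the parameter names are placeholders, only the "qweight"/"lookup_table" suffixes matter to the substring match):

from vllm.model_executor.quantization_utils.awq import AWQConfig
from vllm.model_executor.quantization_utils.squeezellm import SqueezeLLMConfig

print(AWQConfig.get_packed_dim("layers.0.mlp.gate_proj.qweight"))             # 1
print(SqueezeLLMConfig.get_packed_dim("layers.0.mlp.gate_proj.qweight"))      # 0
print(SqueezeLLMConfig.get_packed_dim("layers.0.mlp.gate_proj.lookup_table")) # None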

View File

@@ -0,0 +1,65 @@
from typing import Any, Dict, List
import torch
from vllm.model_executor.quantization_utils.base import QuantizationConfig
class SqueezeLLMConfig(QuantizationConfig):
"""Config class for SqueezeLLM.
Reference: https://arxiv.org/pdf/2306.07629
"""
def __init__(
self,
weight_bits: int,
) -> None:
self.weight_bits = weight_bits
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"SqueezeLLM, but got {self.weight_bits} bits.")
self.pack_factor = 32 // self.weight_bits
def __repr__(self) -> str:
return f"SqueezeLLMConfig(weight_bits={self.weight_bits})"
@classmethod
def get_name(cls) -> str:
return "squeezellm"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 70
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
weight_bits = cls.get_from_keys(config, ["wbits"])
return cls(weight_bits)
@classmethod
def get_packed_tensors(cls) -> Dict[str, int]:
return {"qweight": 0}
@classmethod
def get_transposed_tensor_names(cls) -> List[str]:
return ["qweight"]
@classmethod
def get_col_parallel_tensor_names(cls) -> List[str]:
return ["qweight", "lookup_table"]
@classmethod
def get_row_parallel_tensor_names(cls) -> List[str]:
return ["qweight"]