[Misc] Remove SqueezeLLM (#8220)

Dipika Sikka, 2024-09-06 18:29:03 -04:00 (committed by GitHub)
parent 9db52eab3d
commit 23f322297f
12 changed files with 6 additions and 389 deletions


@@ -181,7 +181,6 @@ set(VLLM_EXT_SRC
   "csrc/pos_encoding_kernels.cu"
   "csrc/activation_kernels.cu"
   "csrc/layernorm_kernels.cu"
-  "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
   "csrc/quantization/gptq/q_gemm.cu"
   "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
   "csrc/quantization/fp8/common.cu"


@@ -170,9 +170,6 @@ void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
 void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                                torch::Tensor& scales);
 
-void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
-                     torch::Tensor lookup_table);
-
 torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                         torch::Tensor b_gptq_qzeros,
                         torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,


@@ -1,216 +0,0 @@
#include <torch/all.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>
#include <c10/cuda/CUDAGuard.h>

#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16

namespace vllm {
namespace squeezellm {

__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}

// 4-bit matvec kernel (LUT-based)
__global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
    const half2* __restrict__ vec,
#else
    const __half2* __restrict__ vec,
#endif
    const int* __restrict__ mat,
#ifndef USE_ROCM
    half2* __restrict__ mul,
#else
    float2* __restrict__ mul,
#endif
    const __half* __restrict__ lookup_table, int height, int width, int batch,
    int vec_height) {
  const int blockwidth2 = BLOCKWIDTH / 2;

  int row = BLOCKHEIGHT4 * blockIdx.x;
  int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;

#ifndef USE_ROCM
  __shared__ half2 blockvec[blockwidth2];
#else
  __shared__ __half2 blockvec[blockwidth2];
#endif

  __shared__ __half deq2[16][BLOCKWIDTH];
  int off = threadIdx.x;
  int column_offset = col * 16;
  for (int val = 0; val < 16; val += 1) {
    int lut_index = column_offset + val;
    deq2[val][off] = lookup_table[lut_index];
  }

  __half res;
#ifndef USE_ROCM
  half2 res2;
  half2 tmp2;
#else
  __half2 res2;
  __half2 tmp2;
#endif

  int i;
  int k;

  unsigned int tmp1;
  unsigned int lut_index1, lut_index2;

  for (int b = 0; b < batch; ++b) {
    i = width * row + col;
    res = __int2half_rd(0);
    k = 0;

    __syncthreads();
    if (threadIdx.x < blockwidth2)
      blockvec[threadIdx.x] =
          vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 +
              threadIdx.x];
    __syncthreads();

    while (k < blockwidth2) {
      tmp1 = as_unsigned(mat[i]);

#ifndef USE_ROCM
      res2 = {};
      tmp2 = {};
#else
      res2.x = __half_as_ushort(__float2half(0));
      res2.y = __half_as_ushort(__float2half(0));
      tmp2.x = __half_as_ushort(__float2half(0));
      tmp2.y = __half_as_ushort(__float2half(0));
#endif

      lut_index1 = tmp1 & 0xF;
      lut_index2 = (tmp1 >> 4) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 0], res2);

      lut_index1 = (tmp1 >> 8) & 0xF;
      lut_index2 = (tmp1 >> 12) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 1], res2);

      lut_index1 = (tmp1 >> 16) & 0xF;
      lut_index2 = (tmp1 >> 20) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 2], res2);

      lut_index1 = (tmp1 >> 24) & 0xF;
      lut_index2 = (tmp1 >> 28) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 3], res2);

#ifndef USE_ROCM
      res = __hadd(__hadd(res2.x, res2.y), res);
#else
      res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)),
                   res);
#endif

      i += width;
      k += 4;
    }

    // col%2 -> only set one of the two values
#ifndef USE_ROCM
    half2 res3 = {};
    if (col % 2 == 0) {
      res3.x = res;
    } else {
      res3.y = res;
    }
#else
    __half2 res3;
    res3.x = __half_as_ushort(__float2half(0));
    res3.y = __half_as_ushort(__float2half(0));
    if (col % 2 == 0) {
      res3.x = __half_as_ushort(res);
    } else {
      res3.y = __half_as_ushort(res);
    }
#endif

#ifndef USE_ROCM
    atomicAdd(&mul[b * width / 2 + col / 2], res3);
#else
    int tmp_addr = b * width / 2 + col / 2;
    atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
    atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
#endif
  }
}

}  // namespace squeezellm
}  // namespace vllm

// 4-bit matvec kernel (LUT-based)
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table) {
  int height = mat.size(0);
  int width = mat.size(1);
  int batch = vec.size(0);
  int vec_height = vec.size(1);

  dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
              (width + BLOCKWIDTH - 1) / BLOCKWIDTH);
  dim3 threads(BLOCKWIDTH);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM
      (half2*)vec.data_ptr<at::Half>(),
#else
      (__half2*)vec.data_ptr<at::Half>(),
#endif
      mat.data_ptr<int>(),
#ifndef USE_ROCM
      (half2*)mul.data_ptr<at::Half>(),
      (__half*)lookup_table.data_ptr<at::Half>(),
#else
      (float2*)mul.data_ptr<float>(),
      (__half*)lookup_table.data_ptr<at::Half>(),
#endif
      height, width, batch, vec_height);
}

#undef BLOCKWIDTH
#undef BLOCKHEIGHT4
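For anyone who still needs to check numbers against the deleted kernel, the computation is easy to restate in plain PyTorch: each int32 in `mat` packs eight 4-bit lookup-table indices little-endian along the input dimension, `lookup_table` holds 16 fp16 centroids per output column, and the kernel accumulates `vec @ dequant(mat)` into `mul`. The reference below is a sketch written for this note, not vLLM code; the name `squeezellm_matmul_ref` is made up.

import torch


def squeezellm_matmul_ref(vec: torch.Tensor, mat: torch.Tensor,
                          lookup_table: torch.Tensor) -> torch.Tensor:
    """Slow reference for the removed NUQ4MatMulKernel.

    Assumed layouts (matching SqueezeLLMLinearMethod.create_weights):
      vec:          [batch, in_features]             fp16 activations
      mat:          [in_features // 8, out_features] int32, eight 4-bit
                    LUT indices packed little-endian per element
      lookup_table: [out_features, 16]               fp16 centroids
    Returns [batch, out_features].
    """
    packed_rows, out_features = mat.shape
    # Unpack the eight nibbles of every int32 along the input dimension.
    shifts = torch.arange(0, 32, 4, device=mat.device)        # [8]
    nibbles = (mat.unsqueeze(-1) >> shifts) & 0xF              # [rows, out, 8]
    idx = nibbles.permute(0, 2, 1).reshape(packed_rows * 8,
                                           out_features)       # [in, out]
    # Per-output-column lookup: weight[i, j] = lookup_table[j, idx[i, j]].
    weight = torch.gather(lookup_table.t(), 0, idx.long())     # [in, out]
    # Accumulate in fp32 for a stable reference, return in the input dtype.
    return (vec.float() @ weight.float()).to(vec.dtype)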


@@ -237,12 +237,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
   ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
 
-  // Quantized GEMM for SqueezeLLM.
-  ops.def(
-      "squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
-      "lookup_table) -> ()");
-  ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm);
-
   // Compute FP8 quantized tensor for given scaling factor.
   ops.def(
       "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");


@@ -119,17 +119,6 @@ The table below shows the compatibility of various quantization implementations
     - ✗
     - ✗
     - ✗
-  * - SqueezeLLM
-    - ✅︎
-    - ✅︎
-    - ✅︎
-    - ✅︎
-    - ✅︎
-    - ✗
-    - ✗
-    - ✗
-    - ✗
-    - ✗
 
 Notes:
 ^^^^^^


@@ -62,7 +62,7 @@ This script evaluates the inference throughput of language models using various
 python3 benchmarks/benchmark_throughput.py --help
 usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL]
-                               [--tokenizer TOKENIZER] [--quantization {awq,gptq,squeezellm,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N]
+                               [--tokenizer TOKENIZER] [--quantization {awq,gptq,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N]
                                [--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code]
                                [--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}]
                                [--quantization-param-path KV_CACHE_quantization_param_path]
@@ -76,7 +76,7 @@ optional arguments:
   --output-len OUTPUT_LEN Output length for each request. Overrides the output length from the dataset.
   --model MODEL
   --tokenizer TOKENIZER
-  --quantization {awq,gptq,squeezellm,None}, -q {awq,gptq,squeezellm,None}
+  --quantization {awq,gptq,None}, -q {awq,gptq,None}
   --tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE
   --n N Number of generated sequences per prompt.
   --use-beam-search


@@ -209,12 +209,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
     torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
 
 
-# squeezellm
-def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
-                    lookup_table: torch.Tensor) -> None:
-    torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table)
-
-
 # marlin
 def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                 b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,


@@ -277,7 +277,7 @@ class ModelConfig:
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
-        rocm_supported_quantization = ["awq", "gptq", "squeezellm", "fp8"]
+        rocm_supported_quantization = ["awq", "gptq", "fp8"]
         optimized_quantization_methods = [
             "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
             "fbgemm_fp8", "compressed_tensors", "compressed-tensors",
@@ -1537,7 +1537,7 @@ class LoRAConfig:
         if model_config.quantization and model_config.quantization not in [
                 "awq", "gptq"
         ]:
-            # TODO support marlin and squeezellm
+            # TODO support marlin
             logger.warning("%s quantization is not tested with LoRA yet.",
                            model_config.quantization)
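For context, the effect of the trimmed `rocm_supported_quantization` list is just a membership check during config validation. A simplified stand-in, not vLLM's exact code (the function name and messages here are illustrative):

from typing import List, Optional

ROCM_SUPPORTED_QUANTIZATION = ["awq", "gptq", "fp8"]  # "squeezellm" no longer listed


def verify_quantization(method: Optional[str], supported: List[str],
                        is_rocm: bool) -> None:
    """Simplified stand-in for ModelConfig._verify_quantization."""
    if method is None:
        return
    method = method.lower()
    if method not in supported:
        raise ValueError(f"Unknown quantization method: {method}. "
                         f"Must be one of {supported}.")
    if is_rocm and method not in ROCM_SUPPORTED_QUANTIZATION:
        raise ValueError(
            f"{method} quantization is currently not supported in ROCm.")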


@@ -55,7 +55,7 @@ class LLM:
             However, if the `torch_dtype` in the config is `float32`, we will
             use `float16` instead.
         quantization: The method used to quantize the model weights. Currently,
-            we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
+            we support "awq", "gptq", and "fp8" (experimental).
             If None, we first check the `quantization_config` attribute in the
             model config file. If that is None, we assume the model weights are
             not quantized and use `dtype` to determine the data type of
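Since the docstring now lists only "awq", "gptq", and "fp8", a caller selecting a quantized checkpoint looks roughly like this; the model id below is only a placeholder for an AWQ checkpoint you actually have, not something this commit prescribes.

from vllm import LLM, SamplingParams

# Placeholder model id: substitute an AWQ-quantized checkpoint you can access.
llm = LLM(model="TheBloke/Llama-2-7B-AWQ", quantization="awq")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
print(outputs[0].outputs[0].text)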


@@ -39,7 +39,7 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
     # unquantizedLinear
     if hasattr(base_layer, "weight"):
         return base_layer.weight.device
-    # GPTQ/AWQ/SqueezeLLM
+    # GPTQ/AWQ
     elif hasattr(base_layer, "qweight"):
         return base_layer.qweight.device
     # marlin


@@ -25,7 +25,6 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.neuron_quant import (
     NeuronQuantConfig)
 from vllm.model_executor.layers.quantization.qqq import QQQConfig
-from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
@@ -43,7 +42,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "gptq_marlin": GPTQMarlinConfig,
     "awq_marlin": AWQMarlinConfig,
     "gptq": GPTQConfig,
-    "squeezellm": SqueezeLLMConfig,
     "compressed-tensors": CompressedTensorsConfig,
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,


@@ -1,138 +0,0 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_hip


class SqueezeLLMConfig(QuantizationConfig):
    """Config class for SqueezeLLM.

    Reference: https://arxiv.org/pdf/2306.07629
    """

    def __init__(
        self,
        weight_bits: int,
    ) -> None:
        self.weight_bits = weight_bits

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"SqueezeLLM, but got {self.weight_bits} bits.")

        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return f"SqueezeLLMConfig(weight_bits={self.weight_bits})"

    def get_name(self) -> str:
        return "squeezellm"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    @staticmethod
    def get_config_filenames() -> List[str]:
        return ["quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
        weight_bits = cls.get_from_keys(config, ["wbits"])
        return cls(weight_bits)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional[QuantizeMethodBase]:
        if isinstance(layer, LinearBase):
            return SqueezeLLMLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class SqueezeLLMLinearMethod(QuantizeMethodBase):
    """Linear method for SqueezeLLM.

    Args:
        quant_config: The SqueezeLLM quantization config.
    """

    def __init__(self, quant_config: SqueezeLLMConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        if input_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            })
        lookup_table = Parameter(
            torch.empty(
                output_size,
                self.quant_config.weight_bits**2,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(lookup_table, {
            "output_dim": 0,
        })

        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("lookup_table", lookup_table)
        set_weight_attrs(lookup_table, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = layer.qweight
        lookup_table = layer.lookup_table
        out_shape = x.shape[:-1] + (qweight.shape[-1], )
        reshaped_x = x.reshape(-1, x.shape[-1])
        if is_hip():
            out_f = torch.zeros(out_shape, dtype=torch.float)
            ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
            out = out_f.to(dtype=torch.float16)
        else:
            # NOTE: The output tensor should be zero-initialized.
            out = torch.zeros(out_shape, dtype=torch.float16)
            ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)

        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)
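For anyone migrating an old SqueezeLLM checkpoint off vLLM, the packed layout that `create_weights` above expected is easy to reproduce. The sketch below assumes indices are assigned by nearest centroid (SqueezeLLM itself derives the 16 centroids per output column with sensitivity-weighted k-means); the function name is made up for this note.

import torch


def pack_squeezellm_weight(weight: torch.Tensor,
                           lookup_table: torch.Tensor) -> torch.Tensor:
    """Pack an fp16 weight into the qweight layout used above.

    weight:       [out_features, in_features]  (nn.Linear convention)
    lookup_table: [out_features, 16]           per-column centroids
    Returns qweight [in_features // 8, out_features] int32, eight 4-bit
    indices per element, little-endian along the input dimension.
    """
    out_features, in_features = weight.shape
    assert in_features % 8 == 0, "in_features must be divisible by pack_factor"
    # Nearest-centroid assignment (assumption about how checkpoints were made).
    dists = (weight.float().unsqueeze(-1) -
             lookup_table.float().unsqueeze(1)).abs()        # [out, in, 16]
    idx = dists.argmin(dim=-1).to(torch.int32)               # [out, in]
    # Kernel-facing layout: indices run along the input dim, columns = outputs.
    idx = idx.t().reshape(in_features // 8, 8, out_features)
    packed = torch.zeros(in_features // 8, out_features,
                         dtype=torch.int32, device=weight.device)
    for k in range(8):  # OR each nibble into its 4-bit slot.
        packed |= idx[:, k, :] << (4 * k)
    return packed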