[Kernel] Initial Activation Quantization Support (#4525)

Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Dipika Sikka 2024-05-23 17:29:18 -04:00 committed by GitHub
parent 5eda2ea02a
commit a1242324c9
17 changed files with 683 additions and 94 deletions


@@ -167,6 +167,7 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"


@@ -93,6 +93,9 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
float scale);
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor lookup_table);


@@ -67,6 +67,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Aligning the number of tokens to be processed by each expert such "
"that it is divisible by the block size.");
ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
"Compute int8 quantized tensor for given scaling factor");
// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def("swap_blocks", &swap_blocks,


@@ -0,0 +1,59 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <cmath>
#include "../../dispatch_utils.h"
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
static const float i8_min =
static_cast<float>(std::numeric_limits<int8_t>::min());
static const float i8_max =
static_cast<float>(std::numeric_limits<int8_t>::max());
// round
float dst = std::nearbyint(x);
// saturate
dst = std::clamp(dst, i8_min, i8_max);
return static_cast<int8_t>(dst);
#else
// CUDA path
uint32_t dst;
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
return reinterpret_cast<const int8_t&>(dst);
#endif
}
namespace vllm {
template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
const scalar_t* __restrict__ input, int8_t* __restrict__ out,
scale_type scale, const int hidden_size) {
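// One thread block per token (grid = num_tokens); each thread strides over
// hidden_size, so rows wider than blockDim.x are covered by the loop below.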
const int tid = threadIdx.x;
const int token_idx = blockIdx.x;
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] =
float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
}
}
} // namespace vllm
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
float scale) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
vllm::static_scaled_int8_quant_kernel<scalar_t, float>
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
out.data_ptr<int8_t>(), scale,
hidden_size);
});
}


@@ -0,0 +1,31 @@
import pytest
import torch
from vllm._C import ops
DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 8192] # Arbitrary values for testing
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
SEEDS = [0]
SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("scale", SCALE)
@torch.inference_mode()
def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
seed: int, scale: float) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
out1 = (x / scale).round().clamp(
torch.iinfo(torch.int8).min,
torch.iinfo(torch.int8).max).to(torch.int8)
out2 = torch.empty_like(x, dtype=torch.int8)
ops.static_scaled_int8_quant(out2, x, scale)
assert torch.allclose(out1, out2,
atol=1) # big atol to account for rounding errors


@@ -0,0 +1,36 @@
"""Test model set-up and weight loading for sparseml-quantized models.
Run `pytest tests/quantization/test_compressed_tensors.py`.
"""
import torch
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor)
def test_compressed_tensors_w8a8_static_setup(vllm_runner):
model_path = "nm-testing/tinyllama-one-shot-static-quant-test-compressed"
llm = vllm_runner(model_path, quantization="sparseml", enforce_eager=True)
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
o_proj = layer.self_attn.o_proj
gate_up_proj = layer.mlp.gate_up_proj
down_proj = layer.mlp.down_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(gate_up_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(down_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)
assert qkv_proj.weight.dtype is torch.int8
assert o_proj.weight.dtype is torch.int8
assert gate_up_proj.weight.dtype is torch.int8
assert qkv_proj.weight_scale.shard_splitter is not None
assert qkv_proj.weight_scale.logical_widths is not None
assert qkv_proj.input_scale.dtype is torch.float32


@@ -251,6 +251,24 @@ def scaled_fp8_quant(
return output, scale
# int8
def static_scaled_int8_quant(input: torch.Tensor,
scale: float) -> torch.Tensor:
"""
Quantize the input tensor to int8 and return the quantized tensor.
Args:
input: The input tensor to be quantized to int8.
scale: Scaling factor for the int8 quantization.
Returns:
torch.Tensor: Output tensor in int8.
"""
q = torch.empty_like(input, dtype=torch.int8)
vllm_ops.static_scaled_int8_quant(q, input, scale)
return q
# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor,


@@ -56,7 +56,6 @@ class LinearMethodBase(QuantizeMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
@@ -77,8 +76,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
output_size_per_partition = sum(output_partition_sizes)
weight = Parameter(torch.empty(output_size_per_partition,
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
@@ -149,15 +147,13 @@ class ReplicatedLinear(LinearBase):
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
def __init__(self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
@@ -210,17 +206,15 @@ class ColumnParallelLinear(LinearBase):
the list would be size 3.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None,
):
def __init__(self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
@@ -228,18 +222,26 @@ class ColumnParallelLinear(LinearBase):
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, tp_size)
assert self.quant_method is not None
self.output_size_per_partition = divide(self.output_size, tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, tp_size)
for output_size in self.output_sizes
]
if output_sizes is None:
output_sizes = [output_size]
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size,
[x // tp_size for x in output_sizes],
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
@@ -317,22 +319,24 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_sizes: List[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
def __init__(self,
input_size: int,
output_sizes: List[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size, sum(output_sizes), bias, gather_output,
skip_bias_add, params_dtype, quant_config,
self.output_sizes)
super().__init__(input_size=input_size,
output_size=sum(output_sizes),
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config)
def weight_loader(self,
param: Parameter,
@@ -343,6 +347,26 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)
param_shard_splitter = getattr(param, "shard_splitter", None)
if output_dim is not None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support output_dim != None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# If a parameter has defined a shard_splitter to be used for
# the weight, it must be applied before the weight is
# loaded/copied to the parameter. The shard_splitter uses the
# loaded_shard_id to place the loaded weight at the correct
# offset within the parameter defined by the linear method.
if loaded_shard_id is None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support loaded_shard_id == None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
@@ -403,6 +427,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size)
# If a param_shard_splitter is defined by the LinearMethod, use it.
elif param_shard_splitter is not None:
logical_widths = getattr(param, "logical_widths", None)
param_data, loaded_weight = param_shard_splitter(
param_data, loaded_weight, loaded_shard_id, logical_widths)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
@@ -415,6 +446,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")
if fp8_scales_shard_indexer is None:
if len(param_data.shape) == 0:
param_data = param_data.reshape(1)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
@@ -443,17 +482,15 @@ class QKVParallelLinear(ColumnParallelLinear):
quant_config: Quantization configure.
"""
def __init__(
self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: Optional[int] = None,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
def __init__(self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: Optional[int] = None,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None):
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
@@ -473,14 +510,19 @@ class QKVParallelLinear(ColumnParallelLinear):
input_size = self.hidden_size
output_size = (self.num_heads +
2 * self.num_kv_heads) * tp_size * self.head_size
output_sizes = [
self.num_heads * tp_size * self.head_size,
self.num_kv_heads * tp_size * self.head_size,
self.num_kv_heads * tp_size * self.head_size
self.output_sizes = [
self.num_heads * self.head_size * tp_size, # q_proj
self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.head_size * tp_size, # v_proj
]
super().__init__(input_size, output_size, bias, False, skip_bias_add,
params_dtype, quant_config, output_sizes)
super().__init__(input_size=input_size,
output_size=output_size,
bias=bias,
gather_output=False,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config)
def weight_loader(self,
param: Parameter,
@@ -490,6 +532,26 @@ class QKVParallelLinear(ColumnParallelLinear):
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)
param_shard_splitter = getattr(param, "shard_splitter", None)
if output_dim is not None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support output_dim != None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# If a parameter has defined a shard_splitter to be used for
# the weight, it must be applied before the weight is
# loaded/copied to the parameter. The shard_splitter uses the
# loaded_shard_id to place the loaded weight at the correct
# offset within the parameter defined by the linear method.
if loaded_shard_id is None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support loaded_shard_id == None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
@@ -528,6 +590,8 @@ class QKVParallelLinear(ColumnParallelLinear):
tp_rank = get_tensor_model_parallel_rank()
assert loaded_shard_id in ["q", "k", "v"]
# If output dim is defined, use the default loading process.
if output_dim is not None:
if loaded_shard_id == "q":
shard_offset = 0
@@ -567,6 +631,12 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_index = ["q", "k", "v"].index(loaded_shard_id)
param_data = param_data.narrow(0, shard_index * shard_size,
shard_size)
# If a param_shard_splitter is defined by the LinearMethod, use it.
elif param_shard_splitter is not None:
logical_widths = getattr(param, "logical_widths", None)
param_data, loaded_weight = param_shard_splitter(
param_data, loaded_weight, loaded_shard_id, logical_widths)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
@@ -578,6 +648,13 @@ class QKVParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"for all partitions.")
if len(param_data.shape) == 0:
param_data = param_data.reshape(1)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
@@ -608,17 +685,15 @@ class RowParallelLinear(LinearBase):
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
):
def __init__(self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
@@ -628,16 +703,15 @@ class RowParallelLinear(LinearBase):
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size_per_partition,
[self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=[self.output_size],
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
@@ -665,12 +739,16 @@ class RowParallelLinear(LinearBase):
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
loaded_weight,
shard_id=0)
if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)


@@ -4,6 +4,8 @@ from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig)
from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
@@ -27,6 +29,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"gptq_marlin": GPTQMarlinConfig,
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
"sparseml": CompressedTensorsConfig,
}


@@ -0,0 +1,151 @@
from typing import Any, Dict, List, Optional
import torch
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsW8A8StaticTensor)
class CompressedTensorsConfig(QuantizationConfig):
def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]):
self.ignore = ignore
self.layer_quant_details = layer_quant_details
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)
def get_scaled_act_names(self) -> List[str]:
return []
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.float16]
# TODO: confirm the minimum compute capability actually required
def get_min_capability(self) -> int:
return 60
def get_name(self) -> str:
return "compressed_tensors"
def get_quant_method(
self, layer: torch.nn.Module
) -> Optional["CompressedTensorsLinearMethod"]:
if isinstance(layer, LinearBase):
return CompressedTensorsLinearMethod(self)
return None
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
layer_quant_details: Dict[str, Any] = dict()
ignore: List[str] = config.get("ignore", None)
for key, quant_config in config["config_groups"].items():
targets = quant_config.get("targets")
for target in targets:
layer_quant_details[target] = {}
layer_quant_details[target]["weight"] = quant_config.get(
"weights")
layer_quant_details[target]["input"] = quant_config.get(
"input_activations")
return cls(layer_quant_details=layer_quant_details, ignore=ignore)
@classmethod
def get_config_filenames(cls) -> List[str]:
return []
def _get_schema(self, weight_quant: Dict, input_quant: Dict):
# TODO: Refactor as additional cases are supported
weight_bit = weight_quant.get("num_bits")
input_bit = input_quant.get("num_bits")
weight_strategy = weight_quant.get("strategy")
input_strategy = input_quant.get("strategy")
weight_symmetric = weight_quant.get("symmetric")
input_symmetric = input_quant.get("symmetric")
is_8_bits = weight_bit == input_bit == 8
is_tensor = weight_strategy == input_strategy == "tensor"
is_symmetric = weight_symmetric and input_symmetric
if is_8_bits and is_tensor and is_symmetric and \
torch.cuda.is_available():
# CompressedTensorsW8A8StaticTensor only supports CUDA path for
# now.
return CompressedTensorsW8A8StaticTensor()
raise NotImplementedError(
"Scheme not supported. Only CUDA, 8-bit static symmtetric "
"per tensor quantization is currently supported")
def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
# TODO: update with matching function from `compressed_tensors`
layer_type_name = None
layer_name_class = type(layer).__name__.lower()
for target in self.layer_quant_details:
if target.lower() in layer_name_class:
layer_type_name = target
break
if layer_type_name is None:
raise ValueError(f"Could not matching target for layer {layer}")
layer_quant_details: Dict[str, Any] = self.layer_quant_details.get(
layer_type_name, None)
if layer_quant_details is None:
raise ValueError(
f"Could not find quantization details for {layer}.")
return self._get_schema(weight_quant=layer_quant_details["weight"],
input_quant=layer_quant_details["input"])
class CompressedTensorsLinearMethod(LinearMethodBase):
def __init__(self, quantization_config: CompressedTensorsConfig):
self.quantization_config = quantization_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""
Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer.
"""
weight_loader = extra_weight_attrs.get("weight_loader")
scheme = self.quantization_config.get_scheme(layer=layer)
scheme.create_weights(
layer=layer,
input_size_per_partition=input_size_per_partition,
output_partition_sizes=output_partition_sizes,
output_size=output_size,
params_dtype=params_dtype,
weight_loader=weight_loader)
layer.scheme = scheme
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None):
"""
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
layer input.
"""
if bias is not None:
raise ValueError("bias is not supported for this linear method")
scheme = layer.scheme
if scheme is None:
raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x)
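
As an aside, the keys read by from_config and _get_schema imply a quantization config shaped roughly like the sketch below. This is illustrative only; the group name, target list, and ignore entry are assumptions, not taken from a real checkpoint.

example_quantization_config = {
    "config_groups": {
        "group_0": {  # group name is arbitrary in this sketch
            "targets": ["Linear"],
            "weights": {"num_bits": 8, "strategy": "tensor", "symmetric": True},
            "input_activations": {"num_bits": 8, "strategy": "tensor",
                                  "symmetric": True},
        },
    },
    "ignore": ["lm_head"],  # layers left unquantized
}
# CompressedTensorsConfig.from_config(example_quantization_config) would map
# the "Linear" target to CompressedTensorsW8A8StaticTensor via _get_schema().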


@@ -0,0 +1,5 @@
from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401
from .compressed_tensors_unquantized import ( # noqa: F401
CompressedTensorsUnquantized)
from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501
CompressedTensorsW8A8StaticTensor)


@@ -0,0 +1,33 @@
from abc import ABC, abstractmethod
import torch
__all__ = ["CompressedTensorsScheme"]
class CompressedTensorsScheme(ABC):
"""
Abstract class used to describe the weight creation and forward pass
of different quantization schemes supported by CompressedTensors.
"""
@abstractmethod
def create_weights(self, *args, **kwargs):
"""
Weight creation for the particular scheme. Inputs to this function
are forwarded from CompressedTensorsLinearMethod.create_weights.
"""
raise NotImplementedError
@abstractmethod
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
"""
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
:param layer: torch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
:param x: input to the layer
"""
raise NotImplementedError


@@ -0,0 +1,39 @@
from typing import Callable, List
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.utils import set_weight_attrs
__all__ = ["CompressedTensorsUnquantized"]
class CompressedTensorsUnquantized(CompressedTensorsScheme):
"""
Implements the scheme for all layers which are ignored
in the CompressedTensors config. The input and loaded weight are used
in a linear transformation.
"""
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
device="cuda",
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {"weight_loader": weight_loader})
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
weight = layer.weight
return F.linear(x, weight)


@@ -0,0 +1,119 @@
from typing import Callable, List, Tuple, Union
import torch
from torch.nn import Parameter
from vllm import _custom_ops as custom_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.utils import set_weight_attrs
__all__ = ["CompressedTensorsW8A8StaticTensor"]
class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id
assert isinstance(shard_id, str)
qkv_idxs = {"q": 0, "k": 1, "v": 2}
assert shard_id in qkv_idxs
return qkv_idxs[shard_id]
def scales_shard_splitter(
self, param: torch.Tensor, loaded_weight: torch.Tensor,
shard_id: Union[str, int],
logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
shard_id = self._shard_id_as_int(shard_id)
offset = sum(logical_widths[:shard_id])
size = logical_widths[shard_id]
# update loaded weight with copies for broadcast.
loaded_weight = loaded_weight.repeat(size)
return param[offset:offset + size], loaded_weight
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
# TODO: remove zero_point parameters once the provided configs drop them
# Note on input/weight scales and zero_points
#
# When the scales have a single value, they are required to live
# on the CPU, for two reasons:
# 1. Performance:
# When the scales (input_scale/weight_scales) have only a single
# value, we perform a scalar broadcast of that value during the
# quant/dequant operations. The "quant" and the "gemm+dequant"
# kernels accept the Scalar by-value. These tensors are allocated
# on the CPU in order to avoid the GPU-to-CPU copy when passing
# by-value.
#
# 2. CUDA Graphs:
# CUDA Graphs don't support GPU-to-CPU copy operations during
# stream capture.
#
# TODO: zero-points are not supported yet. But we expect a similar
# pattern.
is_tensor_partitioned = len(output_partition_sizes) != 1
weight_scale_dim = sum(
output_partition_sizes) if is_tensor_partitioned else 1
weight_scale_device = "cpu" if weight_scale_dim == 1 else "cuda"
input_scale = Parameter(torch.empty(1,
device="cpu",
dtype=torch.float32),
requires_grad=False)
input_zero_point = Parameter(torch.empty(1,
device="cpu",
dtype=torch.int8),
requires_grad=False)
weight_scale = Parameter(torch.empty(weight_scale_dim,
device=weight_scale_device,
dtype=torch.float32),
requires_grad=False)
weight_zero_point = Parameter(torch.empty(1,
device="cpu",
dtype=torch.int8),
requires_grad=False)
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.int8),
requires_grad=False)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
set_weight_attrs(weight, {"weight_loader": weight_loader})
layer.register_parameter("input_scale", input_scale)
set_weight_attrs(input_scale, {"weight_loader": weight_loader})
layer.register_parameter("input_zero_point", input_zero_point)
set_weight_attrs(input_zero_point, {"weight_loader": weight_loader})
layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
set_weight_attrs(
weight_scale, {
"shard_splitter": self.scales_shard_splitter,
"logical_widths": output_partition_sizes
})
layer.register_parameter("weight_zero_point", weight_zero_point)
set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader})
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
weight = layer.weight
weight_scale = layer.weight_scale
act_scale = layer.input_scale
# Input quantize
x_q = custom_ops.static_scaled_int8_quant(x, act_scale[0].item())
return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale,
weight_scale, x.dtype)
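
A small sketch of what scales_shard_splitter does when loading per-shard scales into the fused weight_scale parameter; the widths and the scale value are illustrative, not from a real model.

import torch

logical_widths = [4096, 1024, 1024]              # q, k, v partition widths (illustrative)
weight_scale = torch.zeros(sum(logical_widths))  # merged scale parameter
k_scale = torch.tensor(0.02)                     # per-tensor scale loaded for k_proj

shard_id = 1                                     # "k" maps to 1 via _shard_id_as_int
offset = sum(logical_widths[:shard_id])          # 4096
size = logical_widths[shard_id]                  # 1024
# scales_shard_splitter returns (param[offset:offset + size], k_scale.repeat(size));
# the weight loader then copies the repeated scale into that slice:
weight_scale[offset:offset + size] = k_scale.repeat(size)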


@@ -120,6 +120,13 @@ def get_quant_config(model_config: ModelConfig,
# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
None)
if hf_quant_config is None:
compression_config = getattr(model_config.hf_config,
"compression_config", None)
if compression_config is not None:
hf_quant_config = compression_config.get("quantization_config",
None)
if hf_quant_config is not None:
return quant_cls.from_config(hf_quant_config)
model_name_or_path = model_config.model
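
For orientation, the fallback above expects the HF config to carry something like the following; only the attribute name and nesting are implied by the code, the inner contents are assumed (see the illustrative config sketch earlier).

# hypothetical value of model_config.hf_config.compression_config
compression_config = {
    "quantization_config": {
        "config_groups": {"group_0": {"targets": ["Linear"]}},
        "ignore": [],
    },
}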


@@ -62,11 +62,12 @@ class LlamaMLP(nn.Module):
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
input_size=hidden_size,
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
self.down_proj = RowParallelLinear(input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config)
if hidden_act != "silu":
@@ -120,16 +121,16 @@ class LlamaAttention(nn.Module):
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
)
@@ -263,8 +264,10 @@ class LlamaModel(nn.Module):
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
LlamaDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config)
for idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)