[Misc] Update compressed-tensors WNA16 to support zero-points (#14211)

Dipika Sikka 2025-04-15 09:33:51 -04:00 committed by GitHub
parent 280d62b8a2
commit 54a66e5fee
6 changed files with 85 additions and 45 deletions


@@ -261,16 +261,23 @@ def test_compressed_tensors_w8a8_dynamic_per_token(
 @pytest.mark.parametrize(
     "wNa16_args",
-    [
-        ("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
-        ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
-        ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
-    ],
+    [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8,
+      True, False),
+     ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8, True,
+      False),
+     ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4,
+      True, False),
+     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-awq-group128-asym256", "group", 128,
+      8, False, False),
+     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-Channel",
+      "channel", None, 8, False, False),
+     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder",
+      "group", 128, 8, False, True)],
 )
 @pytest.mark.skipif(not current_platform.is_cuda(),
                     reason="The tests are skipped on non-CUDA platform.")
 def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
-    model, strategy, group, pack_factor = wNa16_args
+    model, strategy, group, pack_factor, symmetric, has_g_idx = wNa16_args
     with vllm_runner(model) as llm:
 
         def check_model(model):
@@ -286,6 +293,8 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
                                                   if group is None else group)
 
             assert qkv_proj.scheme.pack_factor == pack_factor
+            assert qkv_proj.scheme.symmetric == symmetric
+            assert qkv_proj.scheme.has_g_idx == has_g_idx
 
         llm.apply_model(check_model)
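
Each wNa16_args tuple now carries two extra flags, symmetric and has_g_idx. A minimal sketch of how one of the new asymmetric cases unpacks (values copied from the list above; the pack-factor arithmetic assumes 4-bit weights packed into int32 words):

# Standalone illustration only, not part of the test file.
wNa16_args = ("nm-testing/TinyLlama-1.1B-Chat-v1.0-awq-group128-asym256",
              "group", 128, 8, False, False)
model, strategy, group, pack_factor, symmetric, has_g_idx = wNa16_args

assert strategy == "group" and group == 128  # group-wise quantization
assert pack_factor == 32 // 4                # eight 4-bit weights per int32
assert symmetric is False                    # asymmetric -> zero points stored
assert has_g_idx is False                    # no activation reordering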


@@ -302,14 +302,12 @@ class CompressedTensorsConfig(QuantizationConfig):
     def _is_wNa16_group_channel(self, weight_quant: BaseModel,
                                 input_quant: BaseModel) -> bool:
         input_quant_none = input_quant is None
-        is_symmetric = weight_quant.symmetric
         is_channel_group = (
             weight_quant.strategy == QuantizationStrategy.CHANNEL.value
             or weight_quant.strategy == QuantizationStrategy.GROUP.value)
         is_static = not weight_quant.dynamic
 
-        return (is_channel_group and input_quant_none and is_symmetric
-                and is_static)
+        return (is_channel_group and input_quant_none and is_static)
 
     def _get_scheme_from_parts(
             self, weight_quant: BaseModel,
@@ -319,6 +317,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         if self._is_wNa16_group_channel(weight_quant, input_quant):
             if (self.quant_format == CompressionFormat.marlin_24.value
                     and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
+                assert weight_quant.symmetric
                 return CompressedTensorsW4A16Sparse24(
                     strategy=weight_quant.strategy,
                     num_bits=weight_quant.num_bits,
@@ -328,6 +327,7 @@ class CompressedTensorsConfig(QuantizationConfig):
                 return CompressedTensorsWNA16(
                     num_bits=weight_quant.num_bits,
                     strategy=weight_quant.strategy,
+                    symmetric=weight_quant.symmetric,
                     group_size=weight_quant.group_size,
                     actorder=weight_quant.actorder)
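
With the is_symmetric check removed from _is_wNa16_group_channel, asymmetric channel/group checkpoints now also take the WNA16 path; the symmetric flag is forwarded to the scheme instead, and the Marlin-24 branch asserts symmetry explicitly. A simplified sketch of the resulting predicate, using a hypothetical stand-in for the compressed-tensors config object:

from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyWeightQuant:
    """Hypothetical stand-in for the compressed-tensors BaseModel config."""
    strategy: str       # "channel" or "group"
    dynamic: bool
    symmetric: bool


def is_wNa16_group_channel(weight_quant: ToyWeightQuant,
                           input_quant: Optional[object]) -> bool:
    input_quant_none = input_quant is None
    is_channel_group = weight_quant.strategy in ("channel", "group")
    is_static = not weight_quant.dynamic
    # Symmetry no longer gates the WNA16 path; it only selects the quant type
    # and zero-point handling inside CompressedTensorsWNA16.
    return is_channel_group and input_quant_none and is_static


assert is_wNa16_group_channel(
    ToyWeightQuant(strategy="group", dynamic=False, symmetric=False), None)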


@@ -12,11 +12,15 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_repeat_scales_on_all_ranks)
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
+                                           PackedColumnParameter,
                                            PackedvLLMParameter,
                                            RowvLLMParameter)
+# yapf: enable
 from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
@@ -26,6 +30,7 @@ WNA16_SUPPORTED_TYPES_MAP = {
     4: scalar_types.uint4b8,
     8: scalar_types.uint8b128
 }
+WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
 WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
@@ -36,10 +41,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                  strategy: str,
                  num_bits: int,
                  group_size: Optional[int] = None,
+                 symmetric: Optional[bool] = True,
                  actorder: Optional[ActivationOrdering] = None):
 
         self.pack_factor = 32 // num_bits
         self.strategy = strategy
+        self.symmetric = symmetric
         self.group_size = -1 if group_size is None else group_size
         self.has_g_idx = actorder == ActivationOrdering.GROUP
@@ -53,7 +60,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                 f"Unsupported num_bits = {num_bits}. "
                 f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")
 
-        self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]
+        self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
+                           if not self.symmetric else
+                           WNA16_SUPPORTED_TYPES_MAP[num_bits])
 
     @classmethod
     def get_min_capability(cls) -> int:
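
The symmetric path keeps the bias-encoded types (uint4b8 / uint8b128, i.e. unsigned values stored with an implicit bias of 8 or 128), while the asymmetric path switches to plain uint4 / uint8 plus an explicit zero point. A small worked example of the two dequantization conventions (illustrative numbers, not taken from a real checkpoint):

scale = 0.05

# Symmetric 4-bit (scalar_types.uint4b8): stored q in [0, 15] with a fixed
# implicit zero point of 8 baked into the type.
q_sym = 11
w_sym = scale * (q_sym - 8)      # 0.15

# Asymmetric 4-bit (scalar_types.uint4): stored q in [0, 15] plus an explicit
# per-group zero point loaded from the checkpoint.
q_asym, zp = 11, 5
w_asym = scale * (q_asym - zp)   # 0.30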
@@ -75,7 +84,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
             weight_type=self.quant_type,
             act_type=params_dtype,
             group_size=self.group_size,
-            zero_points=False,
+            zero_points=not self.symmetric,
             has_g_idx=self.has_g_idx
         )
@@ -120,13 +129,37 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                 dtype=params_dtype,
             )
         }
+        zeros_args = {
+            "weight_loader":
+            weight_loader,
+            "data":
+            torch.zeros(
+                output_size_per_partition // self.pack_factor,
+                scales_and_zp_size,
+                dtype=torch.int32,
+            )
+        }
+
         if not partition_scales:
             weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                       **weight_scale_args)
+
+            if not self.symmetric:
+                qzeros = PackedColumnParameter(output_dim=0,
+                                               packed_dim=0,
+                                               packed_factor=self.pack_factor,
+                                               **zeros_args)
         else:
             weight_scale = GroupQuantScaleParameter(output_dim=0,
                                                     input_dim=1,
                                                     **weight_scale_args)
+            if not self.symmetric:
+                qzeros = PackedvLLMParameter(input_dim=1,
+                                             output_dim=0,
+                                             packed_dim=0,
+                                             packed_factor=self.pack_factor,
+                                             **zeros_args)
 
         # A 2D array defining the original shape of the weights
         # before packing
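
The zero-point tensor mirrors the scales' group layout but is packed along the output dimension: each int32 row holds pack_factor output channels, with one column per scale group. A rough shape check under assumed sizes (hypothetical dimensions; the exact scales_and_zp_size also depends on how scales are replicated across tensor-parallel ranks):

# Hypothetical partition sizes: 4-bit weights, group_size = 128.
output_size_per_partition = 4096
input_size_per_partition = 4096
pack_factor = 32 // 4                                         # 8 weights/int32
group_size = 128

scales_and_zp_size = input_size_per_partition // group_size   # 32 groups

zeros_shape = (output_size_per_partition // pack_factor,      # 512 packed rows
               scales_and_zp_size)                            # 32 columns
assert zeros_shape == (512, 32)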
@@ -138,6 +171,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
         layer.register_parameter("weight_scale", weight_scale)
         layer.register_parameter("weight_shape", weight_shape)
+        if not self.symmetric:
+            layer.register_parameter("weight_zero_point", qzeros)
+
         # group index (for activation reordering)
         if self.has_g_idx:
             weight_g_idx = RowvLLMParameter(data=torch.empty(
@@ -151,7 +187,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
         self.kernel = kernel_type(mp_linear_kernel_config,
                                   w_q_param_name="weight_packed",
                                   w_s_param_name="weight_scale",
-                                  w_zp_param_name=None,
+                                  w_zp_param_name="weight_zero_point",
                                   w_gidx_param_name="weight_g_idx")
 
         # Checkpoints are serialized in compressed-tensors format, which is


@@ -26,17 +26,14 @@ class MacheteLinearKernel(MPLinearKernel):
     @classmethod
     def can_implement(cls,
                       c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
         if c.has_g_idx and\
             c.partition_weight_shape[0] != c.full_weight_shape[0]:
             return False, "Act reordering currently not supported by Machete, "\
                           "when the input features are partitioned across "\
                           "devices"
 
         if c.zero_points:
-            return False, "Zero points currently not supported by "\
-                          " Compressed Tensors + Machete. (Kernel supports it"\
-                          " but CompressedTensorsWNA16 does not so support has"\
-                          " not been added to MacheteWNA16Kernel yet"
+            return False, "Zero points currently not supported by Machete"
 
         if c.weight_type not in query_machete_supported_quant_types(
                 c.zero_points):


@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
     check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
     marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
-    query_marlin_supported_quant_types)
+    marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            permute_param_layout_)
 
@@ -25,10 +25,6 @@ class MarlinLinearKernel(MPLinearKernel):
     @classmethod
     def can_implement(cls,
                       c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
-        if c.zero_points:
-            return False, "Zero points currently not supported by "\
-                          " MarlinLinearKernel. Will be added when AWQMarlin "\
-                          "is migrated over to using MPLinearKernel backend"
-
         quant_types = query_marlin_supported_quant_types(c.zero_points)
         if c.weight_type not in quant_types:
@@ -67,28 +63,6 @@ class MarlinLinearKernel(MPLinearKernel):
         if self.w_zp_name is None:
             self.w_zp_name = "w_zp"
 
-        if c.has_g_idx:
-            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
-                getattr(layer, self.w_gidx_name))
-            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
-            layer.g_idx_sort_indices = g_idx_sort_indices
-        else:
-            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
-            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
-
-        if c.zero_points:
-            pass
-            # TODO (lucas): add the following when AWQMarlin is migrated over to
-            # using MPLinearKernel backend
-            # self._transform_param(layer, self.w_zp_name, lambda x: \
-            #     marlin_zero_points(
-            #         x,
-            #         size_k=c.partition_weight_shape[0],
-            #         size_n=c.partition_weight_shape[1],
-            #         num_bits=c.weight_type.size_bits))
-        else:
-            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
-
         def transform_w_q(x):
             assert isinstance(x, BasevLLMParameter)
             permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
@@ -108,6 +82,28 @@ class MarlinLinearKernel(MPLinearKernel):
                                       group_size=c.group_size)
             return x
 
+        if c.has_g_idx:
+            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
+                getattr(layer, self.w_gidx_name))
+            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
+            layer.g_idx_sort_indices = g_idx_sort_indices
+        else:
+            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
+            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+
+        if c.zero_points:
+            grouped_k = (c.partition_weight_shape[0] //
+                         c.group_size if c.group_size != -1 else 1)
+            self._transform_param(layer, self.w_zp_name, lambda x: \
+                marlin_zero_points(
+                    unpack_cols(x.t(), c.weight_type.size_bits,
+                                grouped_k,
+                                c.partition_weight_shape[1]),
+                    size_k=grouped_k,
+                    size_n=c.partition_weight_shape[1],
+                    num_bits=c.weight_type.size_bits))
+        else:
+            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
+
         self._transform_param(layer, self.w_q_name, transform_w_q)
         self._transform_param(layer, self.w_s_name, transform_w_s)
@@ -131,5 +127,6 @@ class MarlinLinearKernel(MPLinearKernel):
             wtype=c.weight_type,
             input_size_per_partition=c.partition_weight_shape[0],
             output_size_per_partition=c.partition_weight_shape[1],
+            has_zp=self.config.zero_points,
             is_k_full=self.is_k_full,
             bias=bias)
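
The checkpoint stores zero points packed along the output dimension as an (N // pack_factor, num_groups) int32 tensor; the new transform transposes it, unpacks the int32 columns back to one integer per (group, output channel), and hands the result to marlin_zero_points for repacking into Marlin's layout, with has_zp then forwarded to the GEMM call. A sketch of just the shape bookkeeping (hypothetical sizes, plain torch only, no vLLM helpers called):

import torch

# Hypothetical layer: K = 4096 input features, N = 4096 outputs, 4-bit, g=128.
size_k, size_n, num_bits, group_size = 4096, 4096, 4, 128
pack_factor = 32 // num_bits
grouped_k = size_k // group_size                 # 32 zero-point rows

# As allocated in the scheme above: packed along the output dimension.
packed_zp = torch.zeros(size_n // pack_factor, grouped_k, dtype=torch.int32)

# After .t(), groups become rows, matching unpack_cols(..., grouped_k, size_n);
# unpacking then yields one integer per (group, output channel) entry.
assert packed_zp.t().shape == (grouped_k, size_n // pack_factor)   # (32, 512)
unpacked_shape = (grouped_k, size_n)                               # (32, 4096)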


@@ -332,6 +332,7 @@ def apply_gptq_marlin_linear(
         wtype: ScalarType,
         output_size_per_partition: int,
         input_size_per_partition: int,
+        has_zp: bool,
         is_k_full: bool,
         bias: Optional[torch.Tensor] = None,
         use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
@@ -356,8 +357,8 @@ def apply_gptq_marlin_linear(
                                   size_n=output_size_per_partition,
                                   size_k=input_size_per_partition,
                                   is_k_full=is_k_full,
-                                  has_zp=False,
                                   use_atomic_add=use_atomic_add,
+                                  has_zp=has_zp,
                                   use_fp32_reduce=use_fp32_reduce,
                                   is_zp_float=False)
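
Taken together, the asymmetry flag now flows end to end through the stack. A comment-only sketch of the path, using only names visible in the hunks above (the exact hook in which the Marlin kernel repacks its weights is assumed, not shown here):

# symmetric=False in the compressed-tensors checkpoint config
#   -> CompressedTensorsWNA16(symmetric=False)
#        quant_type  = WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]   # uint4 / uint8
#        zero_points = not symmetric                            # True
#        registers "weight_zero_point", wired up as w_zp_param_name
#   -> MarlinLinearKernel weight post-processing
#        unpack_cols(...) + marlin_zero_points(...) repack the zero points
#   -> apply_gptq_marlin_linear(..., has_zp=self.config.zero_points)
#   -> ops.gptq_marlin_gemm(..., has_zp=has_zp)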