diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index 5c928f27..70f716f9 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -261,16 +261,23 @@ def test_compressed_tensors_w8a8_dynamic_per_token(

 @pytest.mark.parametrize(
     "wNa16_args",
-    [
-        ("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
-        ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
-        ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
-    ],
+    [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8,
+      True, False),
+     ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8, True,
+      False),
+     ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4,
+      True, False),
+     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-awq-group128-asym256", "group", 128,
+      8, False, False),
+     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-Channel",
+      "channel", None, 8, False, False),
+     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder",
+      "group", 128, 8, False, True)],
 )
 @pytest.mark.skipif(not current_platform.is_cuda(),
                     reason="The tests are skipped on non-CUDA platform.")
 def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
-    model, strategy, group, pack_factor = wNa16_args
+    model, strategy, group, pack_factor, symmetric, has_g_idx = wNa16_args
     with vllm_runner(model) as llm:

         def check_model(model):
@@ -286,6 +293,8 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
                     if group is None else group)

             assert qkv_proj.scheme.pack_factor == pack_factor
+            assert qkv_proj.scheme.symmetric == symmetric
+            assert qkv_proj.scheme.has_g_idx == has_g_idx

         llm.apply_model(check_model)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index b714d95b..cb9a48d7 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -302,14 +302,12 @@ class CompressedTensorsConfig(QuantizationConfig):
     def _is_wNa16_group_channel(self, weight_quant: BaseModel,
                                 input_quant: BaseModel) -> bool:
         input_quant_none = input_quant is None
-        is_symmetric = weight_quant.symmetric
         is_channel_group = (
             weight_quant.strategy == QuantizationStrategy.CHANNEL.value
             or weight_quant.strategy == QuantizationStrategy.GROUP.value)
         is_static = not weight_quant.dynamic

-        return (is_channel_group and input_quant_none and is_symmetric
-                and is_static)
+        return (is_channel_group and input_quant_none and is_static)

     def _get_scheme_from_parts(
             self, weight_quant: BaseModel,
@@ -319,6 +317,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         if self._is_wNa16_group_channel(weight_quant, input_quant):
             if (self.quant_format == CompressionFormat.marlin_24.value
                     and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
+                assert weight_quant.symmetric
                 return CompressedTensorsW4A16Sparse24(
                     strategy=weight_quant.strategy,
                     num_bits=weight_quant.num_bits,
@@ -328,6 +327,7 @@ class CompressedTensorsConfig(QuantizationConfig):
                 return CompressedTensorsWNA16(
                     num_bits=weight_quant.num_bits,
                     strategy=weight_quant.strategy,
+                    symmetric=weight_quant.symmetric,
                     group_size=weight_quant.group_size,
                     actorder=weight_quant.actorder)

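
Illustrative sketch, not part of the patch: with the `is_symmetric` check removed above, any static, weight-only channel/group config now routes to the WNA16 scheme whether or not it is symmetric. A simplified stand-in for the routing predicate, using a hypothetical `SimpleNamespace` config instead of the real compressed-tensors pydantic models:

# Simplified stand-in for CompressedTensorsConfig._is_wNa16_group_channel
# after this change; `weight_quant` is a types.SimpleNamespace here, not the
# real compressed-tensors model.
from types import SimpleNamespace


def is_wNa16_group_channel(weight_quant, input_quant) -> bool:
    input_quant_none = input_quant is None
    is_channel_group = weight_quant.strategy in ("channel", "group")
    is_static = not weight_quant.dynamic
    # Note: no symmetry requirement anymore; asymmetric (zero-point)
    # checkpoints are also handled by the WNA16 scheme.
    return is_channel_group and input_quant_none and is_static


asym_w4 = SimpleNamespace(strategy="group", dynamic=False, symmetric=False)
assert is_wNa16_group_channel(asym_w4, input_quant=None)
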
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
index 38df09ff..3535dd3f 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@@ -12,11 +12,15 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_repeat_scales_on_all_ranks)
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
+                                           PackedColumnParameter,
                                            PackedvLLMParameter,
                                            RowvLLMParameter)
+# yapf: enable
 from vllm.scalar_type import scalar_types

 logger = init_logger(__name__)
@@ -26,6 +30,7 @@ WNA16_SUPPORTED_TYPES_MAP = {
     4: scalar_types.uint4b8,
     8: scalar_types.uint8b128
 }
+WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
 WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())


@@ -36,10 +41,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                  strategy: str,
                  num_bits: int,
                  group_size: Optional[int] = None,
+                 symmetric: Optional[bool] = True,
                  actorder: Optional[ActivationOrdering] = None):

         self.pack_factor = 32 // num_bits
         self.strategy = strategy
+        self.symmetric = symmetric
         self.group_size = -1 if group_size is None else group_size
         self.has_g_idx = actorder == ActivationOrdering.GROUP

@@ -53,7 +60,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                 f"Unsupported num_bits = {num_bits}. "
                 f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")

-        self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]
+        self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
+                           if not self.symmetric else
+                           WNA16_SUPPORTED_TYPES_MAP[num_bits])

     @classmethod
     def get_min_capability(cls) -> int:
@@ -75,7 +84,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
             weight_type=self.quant_type,
             act_type=params_dtype,
             group_size=self.group_size,
-            zero_points=False,
+            zero_points=not self.symmetric,
             has_g_idx=self.has_g_idx
         )

@@ -120,13 +129,37 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                 dtype=params_dtype,
             )
         }
+
+        zeros_args = {
+            "weight_loader":
+            weight_loader,
+            "data":
+            torch.zeros(
+                output_size_per_partition // self.pack_factor,
+                scales_and_zp_size,
+                dtype=torch.int32,
+            )
+        }
+
         if not partition_scales:
             weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                       **weight_scale_args)
+
+            if not self.symmetric:
+                qzeros = PackedColumnParameter(output_dim=0,
+                                               packed_dim=0,
+                                               packed_factor=self.pack_factor,
+                                               **zeros_args)
         else:
             weight_scale = GroupQuantScaleParameter(output_dim=0,
                                                     input_dim=1,
                                                     **weight_scale_args)
+            if not self.symmetric:
+                qzeros = PackedvLLMParameter(input_dim=1,
+                                             output_dim=0,
+                                             packed_dim=0,
+                                             packed_factor=self.pack_factor,
+                                             **zeros_args)

         # A 2D array defining the original shape of the weights
         # before packing
@@ -138,6 +171,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
         layer.register_parameter("weight_scale", weight_scale)
         layer.register_parameter("weight_shape", weight_shape)

+        if not self.symmetric:
+            layer.register_parameter("weight_zero_point", qzeros)
+
         # group index (for activation reordering)
         if self.has_g_idx:
             weight_g_idx = RowvLLMParameter(data=torch.empty(
@@ -151,7 +187,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
         self.kernel = kernel_type(mp_linear_kernel_config,
                                   w_q_param_name="weight_packed",
                                   w_s_param_name="weight_scale",
-                                  w_zp_param_name=None,
+                                  w_zp_param_name="weight_zero_point",
                                   w_gidx_param_name="weight_g_idx")

         # Checkpoints are serialized in compressed-tensors format, which is
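
Illustrative sketch, not part of the patch: the shapes implied by the `zeros_args` allocation above, for a hypothetical 4-bit, group-128 partition with N=4096 output and K=4096 input features. Zero points are packed along the output dimension with the same `pack_factor = 32 // num_bits` used for the weights, one packed column per scale group:

import torch

num_bits = 4
pack_factor = 32 // num_bits          # 8 nibbles per int32
output_size_per_partition = 4096      # hypothetical N for this partition
input_size_per_partition = 4096       # hypothetical K for this partition
group_size = 128
num_groups = input_size_per_partition // group_size

# Per-group scales: one scale per (output channel, group).
weight_scale = torch.empty(output_size_per_partition, num_groups,
                           dtype=torch.float16)

# Packed zero points: pack_factor 4-bit values per int32 along dim 0,
# matching the torch.zeros(output_size_per_partition // self.pack_factor,
# scales_and_zp_size, dtype=torch.int32) allocation in the patch above.
weight_zero_point = torch.zeros(output_size_per_partition // pack_factor,
                                num_groups, dtype=torch.int32)

assert weight_scale.shape == (4096, 32)
assert weight_zero_point.shape == (512, 32)
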
" f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}") - self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits] + self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits] + if not self.symmetric else + WNA16_SUPPORTED_TYPES_MAP[num_bits]) @classmethod def get_min_capability(cls) -> int: @@ -75,7 +84,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): weight_type=self.quant_type, act_type=params_dtype, group_size=self.group_size, - zero_points=False, + zero_points=not self.symmetric, has_g_idx=self.has_g_idx ) @@ -120,13 +129,37 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): dtype=params_dtype, ) } + + zeros_args = { + "weight_loader": + weight_loader, + "data": + torch.zeros( + output_size_per_partition // self.pack_factor, + scales_and_zp_size, + dtype=torch.int32, + ) + } + if not partition_scales: weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args) + + if not self.symmetric: + qzeros = PackedColumnParameter(output_dim=0, + packed_dim=0, + packed_factor=self.pack_factor, + **zeros_args) else: weight_scale = GroupQuantScaleParameter(output_dim=0, input_dim=1, **weight_scale_args) + if not self.symmetric: + qzeros = PackedvLLMParameter(input_dim=1, + output_dim=0, + packed_dim=0, + packed_factor=self.pack_factor, + **zeros_args) # A 2D array defining the original shape of the weights # before packing @@ -138,6 +171,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_shape", weight_shape) + if not self.symmetric: + layer.register_parameter("weight_zero_point", qzeros) + # group index (for activation reordering) if self.has_g_idx: weight_g_idx = RowvLLMParameter(data=torch.empty( @@ -151,7 +187,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): self.kernel = kernel_type(mp_linear_kernel_config, w_q_param_name="weight_packed", w_s_param_name="weight_scale", - w_zp_param_name=None, + w_zp_param_name="weight_zero_point", w_gidx_param_name="weight_g_idx") # Checkpoints are serialized in compressed-tensors format, which is diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py index 3f0586f6..b3ffeca4 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py @@ -26,17 +26,14 @@ class MacheteLinearKernel(MPLinearKernel): @classmethod def can_implement(cls, c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.has_g_idx and\ c.partition_weight_shape[0] != c.full_weight_shape[0]: return False, "Act reordering currently not supported by Machete, "\ "when the input features are partitioned across "\ "devices" - if c.zero_points: - return False, "Zero points currently not supported by "\ - " Compressed Tensors + Machete. 
(Kernel supports it"\ - " but CompressedTensorsWNA16 does not so support has"\ - " not been added to MacheteWNA16Kernel yet" + return False, "Zero points currently not supported by Machete" if c.weight_type not in query_machete_supported_quant_types( c.zero_points): diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py index e21801cf..7bd824ff 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py @@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx, - query_marlin_supported_quant_types) + marlin_zero_points, query_marlin_supported_quant_types, unpack_cols) from vllm.model_executor.parameter import (BasevLLMParameter, permute_param_layout_) @@ -25,10 +25,6 @@ class MarlinLinearKernel(MPLinearKernel): @classmethod def can_implement(cls, c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: - if c.zero_points: - return False, "Zero points currently not supported by "\ - " MarlinLinearKernel. Will be added when AWQMarlin "\ - "is migrated over to using MPLinearKernel backend" quant_types = query_marlin_supported_quant_types(c.zero_points) if c.weight_type not in quant_types: @@ -67,28 +63,6 @@ class MarlinLinearKernel(MPLinearKernel): if self.w_zp_name is None: self.w_zp_name = "w_zp" - if c.has_g_idx: - g_idx, g_idx_sort_indices = marlin_sort_g_idx( - getattr(layer, self.w_gidx_name)) - self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) - layer.g_idx_sort_indices = g_idx_sort_indices - else: - setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - - if c.zero_points: - pass - # TODO (lucas): add the following when AWQMarlin is migrated over to - # using MPLinearKernel backend - # self._transform_param(layer, self.w_zp_name, lambda x: \ - # marlin_zero_points( - # x, - # size_k=c.partition_weight_shape[0], - # size_n=c.partition_weight_shape[1], - # num_bits=c.weight_type.size_bits)) - else: - setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) - def transform_w_q(x): assert isinstance(x, BasevLLMParameter) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) @@ -108,6 +82,28 @@ class MarlinLinearKernel(MPLinearKernel): group_size=c.group_size) return x + if c.has_g_idx: + g_idx, g_idx_sort_indices = marlin_sort_g_idx( + getattr(layer, self.w_gidx_name)) + self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + if c.zero_points: + grouped_k = (c.partition_weight_shape[0] // + c.group_size if c.group_size != -1 else 1) + self._transform_param(layer, self.w_zp_name, lambda x: \ + marlin_zero_points( + unpack_cols(x.t(), c.weight_type.size_bits, + grouped_k, + c.partition_weight_shape[1]), + size_k=grouped_k, + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits)) + else: + setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, 
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
index 1ccfae91..4a190480 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -332,6 +332,7 @@ def apply_gptq_marlin_linear(
         wtype: ScalarType,
         output_size_per_partition: int,
         input_size_per_partition: int,
+        has_zp: bool,
         is_k_full: bool,
         bias: Optional[torch.Tensor] = None,
         use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
@@ -356,8 +357,8 @@ def apply_gptq_marlin_linear(
                                   size_n=output_size_per_partition,
                                   size_k=input_size_per_partition,
                                   is_k_full=is_k_full,
-                                  has_zp=False,
                                   use_atomic_add=use_atomic_add,
+                                  has_zp=has_zp,
                                   use_fp32_reduce=use_fp32_reduce,
                                   is_zp_float=False)
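
Illustrative sketch, not part of the patch: the numeric difference that `has_zp` selects between. Symmetric checkpoints keep the bias-shifted types (`uint4b8` / `uint8b128`, i.e. unsigned storage with a fixed implicit zero point of 8 or 128), while asymmetric checkpoints use plain `uint4` / `uint8` together with the per-group `weight_zero_point` loaded earlier:

import torch

scale = torch.tensor(0.05)
q = torch.tensor([0, 3, 8, 15], dtype=torch.int32)   # stored 4-bit codes

# Symmetric (uint4b8): a fixed, implicit zero point of 8 baked into the type.
w_sym = (q - 8).to(torch.float32) * scale            # -> [-0.40, -0.25, 0.00, 0.35]

# Asymmetric (uint4 + weight_zero_point): the zero point is a per-group
# value read from the checkpoint instead of a constant.
zp = torch.tensor(5, dtype=torch.int32)
w_asym = (q - zp).to(torch.float32) * scale          # -> [-0.25, -0.10, 0.15, 0.50]
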