[Misc] Update compressed-tensors
WNA16 to support zero-points (#14211)
This commit is contained in:
parent
280d62b8a2
commit
54a66e5fee
@ -261,16 +261,23 @@ def test_compressed_tensors_w8a8_dynamic_per_token(
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"wNa16_args",
|
||||
[
|
||||
("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
|
||||
("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
|
||||
("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
|
||||
],
|
||||
[("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8,
|
||||
True, False),
|
||||
("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8, True,
|
||||
False),
|
||||
("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4,
|
||||
True, False),
|
||||
("nm-testing/TinyLlama-1.1B-Chat-v1.0-awq-group128-asym256", "group", 128,
|
||||
8, False, False),
|
||||
("nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-Channel",
|
||||
"channel", None, 8, False, False),
|
||||
("nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder",
|
||||
"group", 128, 8, False, True)],
|
||||
)
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(),
|
||||
reason="The tests are skipped on non-CUDA platform.")
|
||||
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
|
||||
model, strategy, group, pack_factor = wNa16_args
|
||||
model, strategy, group, pack_factor, symmetric, has_g_idx = wNa16_args
|
||||
with vllm_runner(model) as llm:
|
||||
|
||||
def check_model(model):
|
||||
@ -286,6 +293,8 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
|
||||
if group is None else group)
|
||||
|
||||
assert qkv_proj.scheme.pack_factor == pack_factor
|
||||
assert qkv_proj.scheme.symmetric == symmetric
|
||||
assert qkv_proj.scheme.has_g_idx == has_g_idx
|
||||
|
||||
llm.apply_model(check_model)
|
||||
|
||||
|
@ -302,14 +302,12 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
def _is_wNa16_group_channel(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> bool:
|
||||
input_quant_none = input_quant is None
|
||||
is_symmetric = weight_quant.symmetric
|
||||
is_channel_group = (
|
||||
weight_quant.strategy == QuantizationStrategy.CHANNEL.value
|
||||
or weight_quant.strategy == QuantizationStrategy.GROUP.value)
|
||||
is_static = not weight_quant.dynamic
|
||||
|
||||
return (is_channel_group and input_quant_none and is_symmetric
|
||||
and is_static)
|
||||
return (is_channel_group and input_quant_none and is_static)
|
||||
|
||||
def _get_scheme_from_parts(
|
||||
self, weight_quant: BaseModel,
|
||||
@ -319,6 +317,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
if self._is_wNa16_group_channel(weight_quant, input_quant):
|
||||
if (self.quant_format == CompressionFormat.marlin_24.value
|
||||
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
|
||||
assert weight_quant.symmetric
|
||||
return CompressedTensorsW4A16Sparse24(
|
||||
strategy=weight_quant.strategy,
|
||||
num_bits=weight_quant.num_bits,
|
||||
@ -328,6 +327,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
return CompressedTensorsWNA16(
|
||||
num_bits=weight_quant.num_bits,
|
||||
strategy=weight_quant.strategy,
|
||||
symmetric=weight_quant.symmetric,
|
||||
group_size=weight_quant.group_size,
|
||||
actorder=weight_quant.actorder)
|
||||
|
||||
|
@ -12,11 +12,15 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
MPLinearLayerConfig, choose_mp_linear_kernel)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_repeat_scales_on_all_ranks)
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
# yapf: enable
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -26,6 +30,7 @@ WNA16_SUPPORTED_TYPES_MAP = {
|
||||
4: scalar_types.uint4b8,
|
||||
8: scalar_types.uint8b128
|
||||
}
|
||||
WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
|
||||
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
|
||||
|
||||
|
||||
@ -36,10 +41,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
|
||||
strategy: str,
|
||||
num_bits: int,
|
||||
group_size: Optional[int] = None,
|
||||
symmetric: Optional[bool] = True,
|
||||
actorder: Optional[ActivationOrdering] = None):
|
||||
|
||||
self.pack_factor = 32 // num_bits
|
||||
self.strategy = strategy
|
||||
self.symmetric = symmetric
|
||||
self.group_size = -1 if group_size is None else group_size
|
||||
self.has_g_idx = actorder == ActivationOrdering.GROUP
|
||||
|
||||
@ -53,7 +60,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
|
||||
f"Unsupported num_bits = {num_bits}. "
|
||||
f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")
|
||||
|
||||
self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]
|
||||
self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
|
||||
if not self.symmetric else
|
||||
WNA16_SUPPORTED_TYPES_MAP[num_bits])
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@ -75,7 +84,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
|
||||
weight_type=self.quant_type,
|
||||
act_type=params_dtype,
|
||||
group_size=self.group_size,
|
||||
zero_points=False,
|
||||
zero_points=not self.symmetric,
|
||||
has_g_idx=self.has_g_idx
|
||||
)
|
||||
|
||||
@ -120,13 +129,37 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
|
||||
dtype=params_dtype,
|
||||
)
|
||||
}
|
||||
|
||||
zeros_args = {
|
||||
"weight_loader":
|
||||
weight_loader,
|
||||
"data":
|
||||
torch.zeros(
|
||||
output_size_per_partition // self.pack_factor,
|
||||
scales_and_zp_size,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
}
|
||||
|
||||
if not partition_scales:
|
||||
weight_scale = ChannelQuantScaleParameter(output_dim=0,
|
||||
**weight_scale_args)
|
||||
|
||||
if not self.symmetric:
|
||||
qzeros = PackedColumnParameter(output_dim=0,
|
||||
packed_dim=0,
|
||||
packed_factor=self.pack_factor,
|
||||
**zeros_args)
|
||||
else:
|
||||
weight_scale = GroupQuantScaleParameter(output_dim=0,
|
||||
input_dim=1,
|
||||
**weight_scale_args)
|
||||
if not self.symmetric:
|
||||
qzeros = PackedvLLMParameter(input_dim=1,
|
||||
output_dim=0,
|
||||
packed_dim=0,
|
||||
packed_factor=self.pack_factor,
|
||||
**zeros_args)
|
||||
|
||||
# A 2D array defining the original shape of the weights
|
||||
# before packing
|
||||
@ -138,6 +171,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
layer.register_parameter("weight_shape", weight_shape)
|
||||
|
||||
if not self.symmetric:
|
||||
layer.register_parameter("weight_zero_point", qzeros)
|
||||
|
||||
# group index (for activation reordering)
|
||||
if self.has_g_idx:
|
||||
weight_g_idx = RowvLLMParameter(data=torch.empty(
|
||||
@ -151,7 +187,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
|
||||
self.kernel = kernel_type(mp_linear_kernel_config,
|
||||
w_q_param_name="weight_packed",
|
||||
w_s_param_name="weight_scale",
|
||||
w_zp_param_name=None,
|
||||
w_zp_param_name="weight_zero_point",
|
||||
w_gidx_param_name="weight_g_idx")
|
||||
|
||||
# Checkpoints are serialized in compressed-tensors format, which is
|
||||
|
@ -26,17 +26,14 @@ class MacheteLinearKernel(MPLinearKernel):
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
|
||||
|
||||
if c.has_g_idx and\
|
||||
c.partition_weight_shape[0] != c.full_weight_shape[0]:
|
||||
return False, "Act reordering currently not supported by Machete, "\
|
||||
"when the input features are partitioned across "\
|
||||
"devices"
|
||||
|
||||
if c.zero_points:
|
||||
return False, "Zero points currently not supported by "\
|
||||
" Compressed Tensors + Machete. (Kernel supports it"\
|
||||
" but CompressedTensorsWNA16 does not so support has"\
|
||||
" not been added to MacheteWNA16Kernel yet"
|
||||
return False, "Zero points currently not supported by Machete"
|
||||
|
||||
if c.weight_type not in query_machete_supported_quant_types(
|
||||
c.zero_points):
|
||||
|
@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
|
||||
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
|
||||
marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
|
||||
query_marlin_supported_quant_types)
|
||||
marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
permute_param_layout_)
|
||||
|
||||
@ -25,10 +25,6 @@ class MarlinLinearKernel(MPLinearKernel):
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
|
||||
if c.zero_points:
|
||||
return False, "Zero points currently not supported by "\
|
||||
" MarlinLinearKernel. Will be added when AWQMarlin "\
|
||||
"is migrated over to using MPLinearKernel backend"
|
||||
|
||||
quant_types = query_marlin_supported_quant_types(c.zero_points)
|
||||
if c.weight_type not in quant_types:
|
||||
@ -67,28 +63,6 @@ class MarlinLinearKernel(MPLinearKernel):
|
||||
if self.w_zp_name is None:
|
||||
self.w_zp_name = "w_zp"
|
||||
|
||||
if c.has_g_idx:
|
||||
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
|
||||
getattr(layer, self.w_gidx_name))
|
||||
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
|
||||
layer.g_idx_sort_indices = g_idx_sort_indices
|
||||
else:
|
||||
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
|
||||
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
|
||||
|
||||
if c.zero_points:
|
||||
pass
|
||||
# TODO (lucas): add the following when AWQMarlin is migrated over to
|
||||
# using MPLinearKernel backend
|
||||
# self._transform_param(layer, self.w_zp_name, lambda x: \
|
||||
# marlin_zero_points(
|
||||
# x,
|
||||
# size_k=c.partition_weight_shape[0],
|
||||
# size_n=c.partition_weight_shape[1],
|
||||
# num_bits=c.weight_type.size_bits))
|
||||
else:
|
||||
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
|
||||
|
||||
def transform_w_q(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||
@ -108,6 +82,28 @@ class MarlinLinearKernel(MPLinearKernel):
|
||||
group_size=c.group_size)
|
||||
return x
|
||||
|
||||
if c.has_g_idx:
|
||||
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
|
||||
getattr(layer, self.w_gidx_name))
|
||||
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
|
||||
layer.g_idx_sort_indices = g_idx_sort_indices
|
||||
else:
|
||||
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
|
||||
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
|
||||
|
||||
if c.zero_points:
|
||||
grouped_k = (c.partition_weight_shape[0] //
|
||||
c.group_size if c.group_size != -1 else 1)
|
||||
self._transform_param(layer, self.w_zp_name, lambda x: \
|
||||
marlin_zero_points(
|
||||
unpack_cols(x.t(), c.weight_type.size_bits,
|
||||
grouped_k,
|
||||
c.partition_weight_shape[1]),
|
||||
size_k=grouped_k,
|
||||
size_n=c.partition_weight_shape[1],
|
||||
num_bits=c.weight_type.size_bits))
|
||||
else:
|
||||
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
|
||||
self._transform_param(layer, self.w_q_name, transform_w_q)
|
||||
self._transform_param(layer, self.w_s_name, transform_w_s)
|
||||
|
||||
@ -131,5 +127,6 @@ class MarlinLinearKernel(MPLinearKernel):
|
||||
wtype=c.weight_type,
|
||||
input_size_per_partition=c.partition_weight_shape[0],
|
||||
output_size_per_partition=c.partition_weight_shape[1],
|
||||
has_zp=self.config.zero_points,
|
||||
is_k_full=self.is_k_full,
|
||||
bias=bias)
|
||||
|
@ -332,6 +332,7 @@ def apply_gptq_marlin_linear(
|
||||
wtype: ScalarType,
|
||||
output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
has_zp: bool,
|
||||
is_k_full: bool,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
|
||||
@ -356,8 +357,8 @@ def apply_gptq_marlin_linear(
|
||||
size_n=output_size_per_partition,
|
||||
size_k=input_size_per_partition,
|
||||
is_k_full=is_k_full,
|
||||
has_zp=False,
|
||||
use_atomic_add=use_atomic_add,
|
||||
has_zp=has_zp,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user