[Misc] Update compressed-tensors WNA16 to support zero-points (#14211)
parent 280d62b8a2
commit 54a66e5fee
@@ -261,16 +261,23 @@ def test_compressed_tensors_w8a8_dynamic_per_token(
 @pytest.mark.parametrize(
     "wNa16_args",
-    [
-        ("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
-        ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
-        ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
-    ],
+    [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8,
+      True, False),
+     ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8, True,
+      False),
+     ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4,
+      True, False),
+     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-awq-group128-asym256", "group", 128,
+      8, False, False),
+     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-Channel",
+      "channel", None, 8, False, False),
+     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder",
+      "group", 128, 8, False, True)],
 )
 @pytest.mark.skipif(not current_platform.is_cuda(),
                     reason="The tests are skipped on non-CUDA platform.")
 def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
-    model, strategy, group, pack_factor = wNa16_args
+    model, strategy, group, pack_factor, symmetric, has_g_idx = wNa16_args
     with vllm_runner(model) as llm:
 
         def check_model(model):
@@ -286,6 +293,8 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
                                           if group is None else group)
 
             assert qkv_proj.scheme.pack_factor == pack_factor
+            assert qkv_proj.scheme.symmetric == symmetric
+            assert qkv_proj.scheme.has_g_idx == has_g_idx
 
         llm.apply_model(check_model)
 
@@ -302,14 +302,12 @@ class CompressedTensorsConfig(QuantizationConfig):
     def _is_wNa16_group_channel(self, weight_quant: BaseModel,
                                 input_quant: BaseModel) -> bool:
         input_quant_none = input_quant is None
-        is_symmetric = weight_quant.symmetric
         is_channel_group = (
             weight_quant.strategy == QuantizationStrategy.CHANNEL.value
             or weight_quant.strategy == QuantizationStrategy.GROUP.value)
         is_static = not weight_quant.dynamic
 
-        return (is_channel_group and input_quant_none and is_symmetric
-                and is_static)
+        return (is_channel_group and input_quant_none and is_static)
 
     def _get_scheme_from_parts(
             self, weight_quant: BaseModel,
@@ -319,6 +317,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         if self._is_wNa16_group_channel(weight_quant, input_quant):
             if (self.quant_format == CompressionFormat.marlin_24.value
                     and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
+                assert weight_quant.symmetric
                 return CompressedTensorsW4A16Sparse24(
                     strategy=weight_quant.strategy,
                     num_bits=weight_quant.num_bits,
@@ -328,6 +327,7 @@ class CompressedTensorsConfig(QuantizationConfig):
             return CompressedTensorsWNA16(
                 num_bits=weight_quant.num_bits,
                 strategy=weight_quant.strategy,
+                symmetric=weight_quant.symmetric,
                 group_size=weight_quant.group_size,
                 actorder=weight_quant.actorder)
 
@@ -12,11 +12,15 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_repeat_scales_on_all_ranks)
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
+                                           PackedColumnParameter,
                                            PackedvLLMParameter,
                                            RowvLLMParameter)
+# yapf: enable
 from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
@@ -26,6 +30,7 @@ WNA16_SUPPORTED_TYPES_MAP = {
     4: scalar_types.uint4b8,
     8: scalar_types.uint8b128
 }
+WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
 WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
 
 
@@ -36,10 +41,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                  strategy: str,
                  num_bits: int,
                  group_size: Optional[int] = None,
+                 symmetric: Optional[bool] = True,
                  actorder: Optional[ActivationOrdering] = None):
 
         self.pack_factor = 32 // num_bits
         self.strategy = strategy
+        self.symmetric = symmetric
         self.group_size = -1 if group_size is None else group_size
         self.has_g_idx = actorder == ActivationOrdering.GROUP
 
@@ -53,7 +60,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                 f"Unsupported num_bits = {num_bits}. "
                 f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")
 
-        self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]
+        self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
+                           if not self.symmetric else
+                           WNA16_SUPPORTED_TYPES_MAP[num_bits])
 
     @classmethod
     def get_min_capability(cls) -> int:
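Note: a minimal sketch (plain PyTorch, not part of the diff) of the difference the two type maps above encode. The symmetric types (uint4b8, uint8b128) bake a fixed bias into the stored code, so no zero-point tensor is needed; the asymmetric types (uint4, uint8) keep the raw code and subtract an explicit zero-point loaded from the checkpoint. The zero_point value below is an illustrative example.

import torch

scale = 0.05
q = torch.tensor([0, 3, 8, 15])        # 4-bit weight codes

# symmetric: fixed bias of 8 is part of the scalar type (uint4b8)
w_sym = (q - 8) * scale

# asymmetric: per-group zero-point comes from the checkpoint (example value)
zero_point = 5
w_asym = (q - zero_point) * scale

print(w_sym, w_asym)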
@@ -75,7 +84,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
             weight_type=self.quant_type,
             act_type=params_dtype,
             group_size=self.group_size,
-            zero_points=False,
+            zero_points=not self.symmetric,
             has_g_idx=self.has_g_idx
         )
 
@@ -120,13 +129,37 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                 dtype=params_dtype,
             )
         }
 
+        zeros_args = {
+            "weight_loader":
+            weight_loader,
+            "data":
+            torch.zeros(
+                output_size_per_partition // self.pack_factor,
+                scales_and_zp_size,
+                dtype=torch.int32,
+            )
+        }
+
         if not partition_scales:
             weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                       **weight_scale_args)
+
+            if not self.symmetric:
+                qzeros = PackedColumnParameter(output_dim=0,
+                                               packed_dim=0,
+                                               packed_factor=self.pack_factor,
+                                               **zeros_args)
         else:
             weight_scale = GroupQuantScaleParameter(output_dim=0,
                                                     input_dim=1,
                                                     **weight_scale_args)
+            if not self.symmetric:
+                qzeros = PackedvLLMParameter(input_dim=1,
+                                             output_dim=0,
+                                             packed_dim=0,
+                                             packed_factor=self.pack_factor,
+                                             **zeros_args)
 
         # A 2D array defining the original shape of the weights
         # before packing
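Note: a minimal sketch (illustrative shapes and an assumed packing order, not the compressed-tensors loader) of the layout implied by the zeros_args tensor above. pack_factor = 32 // num_bits codes share one int32, so the zero-point tensor has output_size_per_partition // pack_factor rows and one column per quantization group.

import torch

num_bits, output_size, num_groups = 4, 16, 2
pack_factor = 32 // num_bits            # 8 four-bit codes per int32

# one zero-point per (output row, group), values in [0, 2**num_bits)
zp = torch.randint(0, 2**num_bits, (output_size, num_groups), dtype=torch.int32)

# pack along the output dimension (assumed interleaving, for illustration only)
packed = torch.zeros(output_size // pack_factor, num_groups, dtype=torch.int32)
for i in range(pack_factor):
    packed |= zp[i::pack_factor] << (num_bits * i)

print(packed.shape)  # torch.Size([2, 2]) == (output_size // pack_factor, num_groups)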
@@ -138,6 +171,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
         layer.register_parameter("weight_scale", weight_scale)
         layer.register_parameter("weight_shape", weight_shape)
 
+        if not self.symmetric:
+            layer.register_parameter("weight_zero_point", qzeros)
+
         # group index (for activation reordering)
         if self.has_g_idx:
             weight_g_idx = RowvLLMParameter(data=torch.empty(
@@ -151,7 +187,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
         self.kernel = kernel_type(mp_linear_kernel_config,
                                   w_q_param_name="weight_packed",
                                   w_s_param_name="weight_scale",
-                                  w_zp_param_name=None,
+                                  w_zp_param_name="weight_zero_point",
                                   w_gidx_param_name="weight_g_idx")
 
         # Checkpoints are serialized in compressed-tensors format, which is
@@ -26,17 +26,14 @@ class MacheteLinearKernel(MPLinearKernel):
     @classmethod
     def can_implement(cls,
                       c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
 
         if c.has_g_idx and\
             c.partition_weight_shape[0] != c.full_weight_shape[0]:
             return False, "Act reordering currently not supported by Machete, "\
                           "when the input features are partitioned across "\
                           "devices"
 
         if c.zero_points:
-            return False, "Zero points currently not supported by "\
-                          " Compressed Tensors + Machete. (Kernel supports it"\
-                          " but CompressedTensorsWNA16 does not so support has"\
-                          " not been added to MacheteWNA16Kernel yet"
+            return False, "Zero points currently not supported by Machete"
 
         if c.weight_type not in query_machete_supported_quant_types(
                 c.zero_points):
@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
     check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
     marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
-    query_marlin_supported_quant_types)
+    marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            permute_param_layout_)
 
@@ -25,10 +25,6 @@ class MarlinLinearKernel(MPLinearKernel):
     @classmethod
     def can_implement(cls,
                       c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
-        if c.zero_points:
-            return False, "Zero points currently not supported by "\
-                          " MarlinLinearKernel. Will be added when AWQMarlin "\
-                          "is migrated over to using MPLinearKernel backend"
 
         quant_types = query_marlin_supported_quant_types(c.zero_points)
         if c.weight_type not in quant_types:
@@ -67,28 +63,6 @@ class MarlinLinearKernel(MPLinearKernel):
         if self.w_zp_name is None:
             self.w_zp_name = "w_zp"
 
-        if c.has_g_idx:
-            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
-                getattr(layer, self.w_gidx_name))
-            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
-            layer.g_idx_sort_indices = g_idx_sort_indices
-        else:
-            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
-            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
-
-        if c.zero_points:
-            pass
-            # TODO (lucas): add the following when AWQMarlin is migrated over to
-            # using MPLinearKernel backend
-            # self._transform_param(layer, self.w_zp_name, lambda x: \
-            #     marlin_zero_points(
-            #         x,
-            #         size_k=c.partition_weight_shape[0],
-            #         size_n=c.partition_weight_shape[1],
-            #         num_bits=c.weight_type.size_bits))
-        else:
-            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
-
         def transform_w_q(x):
             assert isinstance(x, BasevLLMParameter)
             permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
@@ -108,6 +82,28 @@ class MarlinLinearKernel(MPLinearKernel):
                 group_size=c.group_size)
             return x
 
+        if c.has_g_idx:
+            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
+                getattr(layer, self.w_gidx_name))
+            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
+            layer.g_idx_sort_indices = g_idx_sort_indices
+        else:
+            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
+            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+
+        if c.zero_points:
+            grouped_k = (c.partition_weight_shape[0] //
+                         c.group_size if c.group_size != -1 else 1)
+            self._transform_param(layer, self.w_zp_name, lambda x: \
+                marlin_zero_points(
+                    unpack_cols(x.t(), c.weight_type.size_bits,
+                                grouped_k,
+                                c.partition_weight_shape[1]),
+                    size_k=grouped_k,
+                    size_n=c.partition_weight_shape[1],
+                    num_bits=c.weight_type.size_bits))
+        else:
+            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
         self._transform_param(layer, self.w_q_name, transform_w_q)
         self._transform_param(layer, self.w_s_name, transform_w_s)
 
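Note: a minimal sketch (hypothetical unpack_int32_cols helper with an assumed bit layout, not vLLM's unpack_cols) of the step the transform above performs before repacking for Marlin: the checkpoint stores zero-points packed into int32 words, and they are expanded back to one value per (group, output column) before marlin_zero_points permutes them.

import torch

def unpack_int32_cols(packed, num_bits, size_k, size_n):
    """Expand int32-packed codes back to one value per (group, output column)."""
    pack_factor = 32 // num_bits
    mask = (1 << num_bits) - 1
    out = torch.empty(size_k, size_n, dtype=torch.int32)
    for i in range(pack_factor):
        out[:, i::pack_factor] = (packed >> (num_bits * i)) & mask
    return out

# e.g. 4-bit zero-points, 2 groups, 16 output columns -> packed shape (2, 2)
packed = torch.randint(0, 2**31 - 1, (2, 2), dtype=torch.int32)
zp = unpack_int32_cols(packed, num_bits=4, size_k=2, size_n=16)
print(zp.shape)  # torch.Size([2, 16])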
@@ -131,5 +127,6 @@ class MarlinLinearKernel(MPLinearKernel):
             wtype=c.weight_type,
             input_size_per_partition=c.partition_weight_shape[0],
             output_size_per_partition=c.partition_weight_shape[1],
+            has_zp=self.config.zero_points,
             is_k_full=self.is_k_full,
             bias=bias)
@@ -332,6 +332,7 @@ def apply_gptq_marlin_linear(
         wtype: ScalarType,
         output_size_per_partition: int,
         input_size_per_partition: int,
+        has_zp: bool,
         is_k_full: bool,
         bias: Optional[torch.Tensor] = None,
         use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
@@ -356,8 +357,8 @@ def apply_gptq_marlin_linear(
                                   size_n=output_size_per_partition,
                                   size_k=input_size_per_partition,
                                   is_k_full=is_k_full,
-                                  has_zp=False,
                                   use_atomic_add=use_atomic_add,
+                                  has_zp=has_zp,
                                   use_fp32_reduce=use_fp32_reduce,
                                   is_zp_float=False)
 
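Note: reference semantics, as a plain-PyTorch sketch (hypothetical reference_wNa16_linear, not the fused Marlin kernel), of the linear op once has_zp is threaded through: each weight code is dequantized as (q - zero_point) * scale for its group before the matmul.

import torch

def reference_wNa16_linear(x, q, scales, zeros, group_size, bias=None):
    # q:      (in_features, out_features) integer weight codes
    # scales: (in_features // group_size, out_features)
    # zeros:  (in_features // group_size, out_features) integer zero-points
    s = scales.repeat_interleave(group_size, dim=0)
    z = zeros.repeat_interleave(group_size, dim=0)
    w = (q.float() - z.float()) * s          # dequantize with zero-points
    out = x @ w
    return out if bias is None else out + bias

x = torch.randn(2, 8)
q = torch.randint(0, 16, (8, 4))
scales = torch.rand(2, 4) * 0.1              # group_size = 4 -> 2 groups
zeros = torch.randint(0, 16, (2, 4))
print(reference_wNa16_linear(x, q, scales, zeros, group_size=4).shape)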