[Misc] Update w4a16 compressed-tensors support to include w8a16 (#5794)

Authored by Dipika Sikka on 2024-06-25 15:23:35 -04:00; committed by GitHub
parent d9b34baedd
commit dd248f7675
5 changed files with 36 additions and 26 deletions
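
This change generalizes the pack-quantized compressed-tensors path from fixed 4-bit weights (w4a16) to any supported weight width with 16-bit activations, currently 4 and 8 bits. A minimal smoke test of the new 8-bit path, assuming the nm-testing/tinyllama-oneshot-w8a16-per-channel checkpoint referenced in the updated test is reachable (this snippet is illustrative, not part of the commit):

from vllm import LLM, SamplingParams

# Load the per-channel w8a16 test checkpoint and run one greedy generation.
llm = LLM(model="nm-testing/tinyllama-oneshot-w8a16-per-channel")
out = llm.generate(["Hello, my name is"],
                   SamplingParams(temperature=0.0, max_tokens=16))
print(out[0].outputs[0].text)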

tests/quantization/test_compressed_tensors.py

@@ -8,9 +8,9 @@ import torch
 
 from vllm import SamplingParams
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
-    CompressedTensorsLinearMethod, CompressedTensorsW4A16,
-    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken,
-    CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
+    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
+    CompressedTensorsWNA16)
 
 
 @pytest.mark.parametrize("model_args", [
@@ -74,26 +74,27 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
         assert qkv_proj.weight.dtype is torch.int8
 
 
-@pytest.mark.parametrize("w4a16_args", [
-    ("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None),
-    ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128),
-])
-def test_compressed_tensors_w4a16(vllm_runner, w4a16_args):
-    model, strategy, group = w4a16_args
+@pytest.mark.parametrize(
+    "wNa16_args",
+    [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
+     ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
+     ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4)])
+def test_compressed_tensors_w4a16(vllm_runner, wNa16_args):
+    model, strategy, group, pack_factor = wNa16_args
     with vllm_runner(model) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
 
         qkv_proj = layer.self_attn.qkv_proj
 
         assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16)
+        assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)
         assert qkv_proj.scheme.strategy == strategy
         assert qkv_proj.scheme.group_size == group
 
         assert qkv_proj.weight_packed.dtype is torch.int32
         assert qkv_proj.weight_scale.dtype is torch.float16
-        assert qkv_proj.weight_packed.pack_factor == 8
+        assert qkv_proj.weight_packed.pack_factor == pack_factor
 
 
 def test_compressed_tensors_w4a16_marlin24(vllm_runner):
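
The added pack_factor column in the parametrization reflects how many quantized weights are packed into each 32-bit word of weight_packed: 32 // num_bits, i.e. 8 for the two 4-bit checkpoints and 4 for the new 8-bit one. A tiny illustration (expected_pack_factor is a hypothetical helper, not part of the diff):

def expected_pack_factor(num_bits: int) -> int:
    # Number of quantized values that fit into one int32 of weight_packed.
    return 32 // num_bits

assert expected_pack_factor(4) == 8  # w4a16 channel/group cases
assert expected_pack_factor(8) == 4  # new w8a16 per-channel case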

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -7,9 +7,10 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme, CompressedTensorsW4A16,
-    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken,
-    CompressedTensorsW8A8StaticTensor)
+    W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
+    CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
+    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
+    CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
     find_first_name_or_class_match)
@@ -108,26 +109,31 @@ class CompressedTensorsConfig(QuantizationConfig):
         return is_8_bits and is_token and is_symmetric and is_dynamic
 
-    def _is_w4a16(self, weight_quant: BaseModel,
-                  input_quant: BaseModel) -> bool:
+    def _is_wNa16_group_channel(self, weight_quant: BaseModel,
+                                input_quant: BaseModel) -> bool:
         input_quant_none = input_quant is None
-        is_4_bits = weight_quant.num_bits == 4
         is_symmetric = weight_quant.symmetric
+        is_channel_group = (
+            weight_quant.strategy == QuantizationStrategy.CHANNEL.value
+            or weight_quant.strategy == QuantizationStrategy.GROUP.value)
         is_static = not weight_quant.dynamic
 
-        return is_4_bits and input_quant_none and is_symmetric and is_static
+        return (is_channel_group and input_quant_none and is_symmetric
+                and is_static)
 
     def _get_schema(self, weight_quant: BaseModel,
                     input_quant: BaseModel) -> "CompressedTensorsScheme":
-        if self._is_w4a16(weight_quant, input_quant):
-            if self.quant_format == CompressionFormat.marlin_24.value:
+        if self._is_wNa16_group_channel(weight_quant, input_quant):
+            if (self.quant_format == CompressionFormat.marlin_24.value
+                    and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
                 return CompressedTensorsW4A16Sparse24(
                     strategy=weight_quant.strategy,
                     num_bits=weight_quant.num_bits,
                     group_size=weight_quant.group_size)
-            if self.quant_format == CompressionFormat.pack_quantized.value:
-                return CompressedTensorsW4A16(
+            if (self.quant_format == CompressionFormat.pack_quantized.value
+                    and weight_quant.num_bits in WNA16_SUPPORTED_BITS):
+                return CompressedTensorsWNA16(
                     num_bits=weight_quant.num_bits,
                     strategy=weight_quant.strategy,
                     group_size=weight_quant.group_size)
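
Scheme selection now checks the compression format and the weight bit-width together, so a config whose bit-width is outside the supported list falls through instead of being forced onto the 4-bit scheme. A standalone sketch of that branching, using simplified stand-in types and illustrative format strings rather than vLLM's actual pydantic configs and CompressionFormat values:

from dataclasses import dataclass
from typing import Optional

W4A16SPARSE24_SUPPORTED_BITS = [4]  # mirrors compressed_tensors_w4a16_24.py
WNA16_SUPPORTED_BITS = [4, 8]       # mirrors compressed_tensors_wNa16.py


@dataclass
class WeightQuantStub:
    """Stand-in for the weight quantization config (not the real BaseModel)."""
    num_bits: int
    strategy: str                 # "channel" or "group"
    group_size: Optional[int] = None


def pick_scheme(quant_format: str, weight_quant: WeightQuantStub) -> str:
    # Mirrors the new _get_schema branching: format first, then bit-width.
    if (quant_format == "marlin-24"
            and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
        return "CompressedTensorsW4A16Sparse24"
    if (quant_format == "pack-quantized"
            and weight_quant.num_bits in WNA16_SUPPORTED_BITS):
        return "CompressedTensorsWNA16"
    raise NotImplementedError("No supported scheme for this configuration")


assert pick_scheme(
    "pack-quantized",
    WeightQuantStub(num_bits=8, strategy="channel")) == "CompressedTensorsWNA16"
assert pick_scheme(
    "marlin-24",
    WeightQuantStub(num_bits=4, strategy="group",
                    group_size=128)) == "CompressedTensorsW4A16Sparse24"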

vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py

@@ -1,10 +1,11 @@
 from .compressed_tensors_scheme import CompressedTensorsScheme  # noqa: F401
 from .compressed_tensors_unquantized import (  # noqa: F401
     CompressedTensorsUnquantized)
-from .compressed_tensors_w4a16 import CompressedTensorsW4A16  # noqa: F401
 from .compressed_tensors_w4a16_24 import (  # noqa: F401
-    CompressedTensorsW4A16Sparse24)
+    W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24)
 from .compressed_tensors_w8a8_dynamictoken import (  # noqa: F401, E501
     CompressedTensorsW8A8DynamicToken)
 from .compressed_tensors_w8a8_statictensor import (  # noqa: F401, E501
     CompressedTensorsW8A8StaticTensor)
+from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS  # noqa: F401
+from .compressed_tensors_wNa16 import CompressedTensorsWNA16  # noqa: F401

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py

@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
 from vllm.model_executor.utils import set_weight_attrs
 
 __all__ = ["CompressedTensorsW4A16Sparse24"]
+W4A16SPARSE24_SUPPORTED_BITS = [4]
 
 
 class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py (renamed from compressed_tensors_w4a16.py)

@@ -11,10 +11,11 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
     marlin_permute_scales)
 from vllm.model_executor.utils import set_weight_attrs
 
-__all__ = ["CompressedTensorsW4A16"]
+__all__ = ["CompressedTensorsWNA16"]
+WNA16_SUPPORTED_BITS = [4, 8]
 
 
-class CompressedTensorsW4A16(CompressedTensorsScheme):
+class CompressedTensorsWNA16(CompressedTensorsScheme):
 
     def __init__(self,
                  strategy: str,
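
After the rename, 4-bit and 8-bit pack-quantized checkpoints share the single CompressedTensorsWNA16 scheme and differ only in num_bits. A hedged construction example mirroring the arguments _get_schema passes for the new per-channel w8a16 case (weight allocation still happens later inside vLLM's linear layers):

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    WNA16_SUPPORTED_BITS, CompressedTensorsWNA16)

assert WNA16_SUPPORTED_BITS == [4, 8]
# Per-channel 8-bit weights, as in nm-testing/tinyllama-oneshot-w8a16-per-channel.
scheme = CompressedTensorsWNA16(num_bits=8, strategy="channel", group_size=None)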