Add: Support for Sparse24Bitmask Compressed Models

Rahul Tuli 2025-02-05 15:30:43 -06:00 committed by GitHub
parent af8486de49
commit 3b2005e1db
4 changed files with 503 additions and 112 deletions
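
This commit teaches vLLM's compressed-tensors integration to load checkpoints whose 2:4 sparse weights are stored in the sparse-24-bitmask compression format, decompressing them at load time before handing them to the CUTLASS 2:4 kernels. A minimal usage sketch (not part of the diff), assuming a GPU with the required capability; the model name is taken from the lm-eval config added below:

from vllm import LLM, SamplingParams

# Sparse24Bitmask-compressed 2:4 checkpoint; the compressed-tensors config is
# read from the checkpoint, so no extra quantization flag should be needed.
llm = LLM(model="nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM")
outputs = llm.generate(["Hello my name is"], SamplingParams(max_tokens=20))
print(outputs[0].outputs[0].text)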

View File

@@ -0,0 +1,11 @@
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2
model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.6353
  - name: "exact_match,flexible-extract"
    value: 0.637
limit: null
num_fewshot: null

View File

@@ -3,6 +3,7 @@
Run `pytest tests/quantization/test_compressed_tensors.py`.
"""
from typing import Optional

import pytest
@@ -22,12 +23,30 @@ from vllm.platforms import current_platform
@pytest.mark.parametrize(
    "model_args",
    [
        (
            "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
            "tensor",
            QuantizationType.INT,
            2560,
            True,
        ),
        (
            "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
            "channel",
            QuantizationType.INT,
            2560,
            True,
        ),
        (
            "nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama",
            "tensor",
            QuantizationType.INT,
            2560,
            False,
        ),
    ],
)
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
    model_path, strategy, quant_type, shape_0, is_symmetric = model_args
    with vllm_runner(model_path, enforce_eager=True) as llm:
@@ -85,21 +104,31 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
        assert output


@pytest.mark.parametrize(
    "model_path",
    [
        "neuralmagic/Llama-3.2-1B-quantized.w8a8",
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym",
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym",
    ],
)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
def test_compressed_tensors_w8a8_logprobs(
    hf_runner,
    vllm_runner,
    example_prompts,
    model_path,
    max_tokens,
    num_logprobs,
):
    dtype = "bfloat16"

    # skip language translation prompt for the static per tensor asym model
    if (model_path ==
            "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
        ):  # noqa: E501
        example_prompts = example_prompts[0:-1]

    with hf_runner(model_path, dtype=dtype) as hf_model:
@@ -125,13 +154,21 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
        assert output


@pytest.mark.parametrize(
    "model_args",
    [
        ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
        ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"),
        (
            "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
            "channel",
        ),
        (
            "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym",
            "channel",
        ),
    ],
)
def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
    model_path, strategy = model_args
    with vllm_runner(model_path, dtype=torch.float16) as llm:
@@ -156,9 +193,12 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
@pytest.mark.parametrize(
    "wNa16_args",
    [
        ("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
        ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
        ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
    ],
)
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
    model, strategy, group, pack_factor = wNa16_args
    with vllm_runner(model) as llm:
@@ -218,7 +258,8 @@ def test_compressed_tensors_fp8(vllm_runner):
                              CompressedTensorsLinearMethod)
            assert isinstance(
                qkv_proj.scheme,
                (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8),
            )

            assert qkv_proj.input_scale.dtype is torch.float32
@@ -241,9 +282,14 @@ def test_compressed_tensors_kv_cache(vllm_runner):
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
def _test_2of4_quant_models(qkv_proj,
                            weight_strategy,
                            input_strategy,
                            format="dense"):
    assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
    assert isinstance(qkv_proj.scheme, CompressedTensors24)
@@ -252,22 +298,39 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
    assert qkv_proj.scheme.quantized
    assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
    sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
    assert sparsity_map.get("Linear").format == format
    assert sparsity_map.get("Linear").sparsity_structure == "2:4"


@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing",
            "channel",
            "token",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing",
            "tensor",
            "tensor",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
            "tensor",
            "token",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:
@@ -286,16 +349,134 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
        assert output


@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM",
            "tensor",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM",
            "tensor",
            "tensor",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_fp8_compressed(vllm_runner, args_2of4):
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
            _test_2of4_quant_models(
                qkv_proj,
                weight_strategy,
                input_strategy,
                format="sparse-24-bitmask",
            )

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM",
            "tensor",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM",
            "tensor",
            "tensor",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_int8_compressed(vllm_runner, args_2of4):
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.int8
            _test_2of4_quant_models(
                qkv_proj,
                weight_strategy,
                input_strategy,
                format="sparse-24-bitmask",
            )

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing",
            "tensor",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
            "tensor",
            "token",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:
@@ -317,10 +498,12 @@ def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="2of4 Sparse is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")],
)
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
    model = args_2of4
    with vllm_runner(model) as llm:
@@ -337,7 +520,9 @@ def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
            assert qkv_proj.scheme.input_quant is None
            assert not qkv_proj.scheme.quantized
            assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
            sparsity_map = (
                qkv_proj.quant_method.quantization_config.sparsity_scheme_map
            )  # noqa: E501
            assert sparsity_map.get("Linear").format == "dense"
            assert sparsity_map.get("Linear").sparsity_structure == "2:4"
@@ -346,3 +531,38 @@ def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4", [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")])
def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4):
    model = args_2of4
    with vllm_runner(model) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensors24)
            assert qkv_proj.scheme.weight_quant is None
            assert qkv_proj.scheme.input_quant is None
            assert not qkv_proj.scheme.quantized
            assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
            sparsity_map = (
                qkv_proj.quant_method.quantization_config.sparsity_scheme_map
            )  # noqa: E501
            assert sparsity_map.get("Linear").format == "sparse-24-bitmask"
            assert sparsity_map.get("Linear").sparsity_structure == "2:4"

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output

View File

@@ -417,15 +417,22 @@ class CompressedTensorsConfig(QuantizationConfig):
                return None

            # Have a valid sparsity scheme
            # Validate layer is supported by Cutlass 2:4 Kernel
            model_compression_config = (None if sparsity_scheme is None
                                        or sparsity_scheme.format == "dense"
                                        else self.config)

            scheme = CompressedTensors24(
                quantized=weight_quant is not None or input_quant is not None,
                weight_quant=weight_quant,
                input_quant=input_quant,
                model_compression_config=model_compression_config,
            )
        elif weight_quant is None:
            logger.warning_once("Acceleration for non-quantized schemes is "
                                "not supported by Compressed Tensors. "
                                "Falling back to UnquantizedLinearMethod")
            return None
        else:
            # Find the quant_scheme
            scheme = self._get_scheme_from_parts(  # type: ignore
@@ -475,10 +482,21 @@ class CompressedTensorsConfig(QuantizationConfig):
        :return: True if the layer is supported by the Cutlass 2:4 Kernel
            False otherwise
        """
        if sparsity_scheme is None:
            return False

        is_valid_sparsity_structure: bool = (
            sparsity_scheme.sparsity_structure ==
            SparsityStructure.TWO_FOUR.value)

        valid_compressors = {
            CompressionFormat.dense.value,
            CompressionFormat.sparse_24_bitmask.value
        }

        is_valid_sparsity = (is_valid_sparsity_structure
                             and sparsity_scheme.format in valid_compressors)

        if not is_valid_sparsity:
            return False
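
With this change, the CUTLASS 2:4 path accepts both uncompressed ("dense") and bitmask-compressed ("sparse-24-bitmask") checkpoints, provided the sparsity structure is 2:4. A standalone sketch of that predicate (the helper name `is_cutlass_24_eligible` and the plain-string arguments are illustrative, not part of the diff):

def is_cutlass_24_eligible(sparsity_structure: str, fmt: str) -> bool:
    # 2:4 structure is required; the checkpoint may be stored dense or as a
    # sparse-24-bitmask compressed artifact.
    return sparsity_structure == "2:4" and fmt in {"dense", "sparse-24-bitmask"}

assert is_cutlass_24_eligible("2:4", "sparse-24-bitmask")   # newly supported
assert is_cutlass_24_eligible("2:4", "dense")               # previously supported
assert not is_cutlass_24_eligible("unstructured", "dense")  # rejected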

View File

@@ -1,13 +1,17 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from compressed_tensors import CompressionFormat, ModelCompressor
from compressed_tensors.quantization import (QuantizationArgs,
                                             QuantizationStrategy,
                                             QuantizationType)
from compressed_tensors.utils import combine_shards

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -22,26 +26,39 @@ __all__ = ["CompressedTensors24"]
class CompressedTensors24(CompressedTensorsScheme):

    def __init__(
        self,
        quantized: bool = False,
        weight_quant: Optional[QuantizationArgs] = None,
        input_quant: Optional[QuantizationArgs] = None,
        model_compression_config: Optional[Dict[str, Any]] = None,
    ):
        self.quantized = quantized
        self.weight_quant = weight_quant
        self.input_quant = input_quant
        self.model_compressor = (
            ModelCompressor.from_compression_config(model_compression_config)
            if model_compression_config is not None else None)
        self.do_sparse_decompress = (
            self.model_compressor is not None
            and self.model_compressor.sparsity_config.format
            == CompressionFormat.sparse_24_bitmask.value)

    @classmethod
    def get_min_capability(cls) -> int:
        # Only cutlass 3.x kernels are implemented so far
        return 90

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size: int,
        output_partition_sizes: List[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        if not sparse_cutlass_supported():
            raise ValueError(
                "Sparse CUTLASS not supported. vLLM must be built with "
@@ -49,16 +66,56 @@ class CompressedTensors24(CompressedTensorsScheme):
        self.output_dtype = params_dtype
        layer.logical_widths = output_partition_sizes
        layer.input_size = input_size
        layer.input_size_per_partition = input_size_per_partition
        self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)

        # parameter to store uncompressed weight
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=self.weights_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        if self.do_sparse_decompress:
            assert all(
                partition_size % 8 == 0
                for partition_size in output_partition_sizes), (
                    "All partitions must be divisible by 8 for "
                    "2:4 sparse compressed models")

            shape = BasevLLMParameter(
                data=torch.empty(2, 1, dtype=torch.int64),
                weight_loader=weight_loader,
            )
            compressed_weight = ModelWeightParameter(
                data=torch.empty(
                    sum(output_partition_sizes),
                    input_size_per_partition // 2,
                    dtype=self.weights_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
            bitmask = ModelWeightParameter(
                data=torch.empty(
                    sum(output_partition_sizes),
                    input_size_per_partition // 8,
                    dtype=torch.uint8,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )

            layer.register_parameter("shape", shape)
            layer.register_parameter("compressed", compressed_weight)
            layer.register_parameter("bitmask", bitmask)

        # Check if quantized, not just 2:4 Sparse
        if self.quantized:
@@ -68,14 +125,16 @@ class CompressedTensors24(CompressedTensorsScheme):
                    data=torch.empty((sum(output_partition_sizes), 1),
                                     dtype=torch.float32),
                    output_dim=0,
                    weight_loader=weight_loader,
                )
            else:
                assert (self.weight_quant and self.weight_quant.strategy
                        == QuantizationStrategy.TENSOR.value)
                weight_scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )

            layer.register_parameter("weight_scale", weight_scale)
@@ -84,9 +143,10 @@
            # register input quant scale
            assert (self.input_quant.strategy ==
                    QuantizationStrategy.TENSOR.value)
            input_scale = BasevLLMParameter(
                data=torch.empty(1, dtype=torch.float32),
                weight_loader=weight_loader,
            )

            layer.register_parameter("input_scale", input_scale)
@@ -107,13 +167,25 @@
        """
        Compress weights after loading. Store compressed weight and meta
            tensor

        :post-condition: layer.w_compressed and layer.meta are
            set to the compressed weight and meta tensor in the
            format expected by the Cutlass kernels
        :param layer: The layer with the weights to be processed
        """
        if self.do_sparse_decompress:
            layer.weight.data = self._decompress_bitmask_compressed_weight(
                compressed=layer.compressed,
                bitmask=layer.bitmask,
                layer=layer,
            )

            # compressed and bitmask tensors
            # are no longer needed after decompression
            del layer.compressed
            del layer.bitmask

        # torch.compile workaround
        if hasattr(layer, "input_scale"):
            layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
@@ -121,10 +193,13 @@ class CompressedTensors24(CompressedTensorsScheme):
        if self.weight_quant:
            if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
                layer.weight_scale = torch.nn.Parameter(
                    convert_to_channelwise(
                        weight_scale=layer.weight_scale,
                        logical_widths=layer.logical_widths,
                    ),
                    requires_grad=False,
                )
            else:
                # torch.compile workaround
                layer.weight_scale = torch.nn.Parameter(
@@ -134,20 +209,22 @@
        layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
        layer.meta = torch.nn.Parameter(meta, requires_grad=False)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Returns the output tensor for the layer with 2:4
        sparse compressed weights, given the input tensor
        and bias

        :param layer: The layer with 2:4 sparse compressed
            weights to be used for the computation
        :param x: The input tensor to the layer
        :param bias: The bias to be added to the output tensor
        :return: The output tensor of the layer
        """
        if self.quantized:
            scale = None
@@ -171,13 +248,15 @@
            input_scale = layer.input_scale
            q_input = x

        out = ops.cutlass_scaled_sparse_mm(
            a=q_input,
            bt_nzs=layer.weight,
            bt_meta=layer.meta,
            scale_a=input_scale,
            scale_b=layer.weight_scale,
            out_dtype=self.output_dtype,
            bias=bias,
        )

        assert out.is_contiguous()
        return out
@@ -203,8 +282,71 @@
            raise ValueError("Quantization type not supported by Cutlass")

    def _decompress_bitmask_compressed_weight(
        self,
        compressed: torch.Tensor,
        bitmask: torch.Tensor,
        layer: torch.nn.Module,
    ) -> torch.Tensor:
        """
        Decompress a compressed 2:4 sparse weight tensor using the bitmask and
        return the result.

        This function also supports sharded decompression.

        :param compressed: The 2:4 sparse weight tensor compressed using the
            sparse-24-bitmask compressor. This is different from
            `cutlass_sparse_compress` which uses a different scheme (2 bits
            for every nonzero element that represent the coordinate within
            the block of 4). The bitmask compression here uses a bitmask to
            indicate the positions of non-zero elements.
        :param bitmask: The 2:4 bitmask associated with the compressed
            weights, representing the positions of non-zero elements in the
            compressed tensor.
        :param layer: The layer whose weights need to be processed after
            loading.
        :return: The decompressed 2:4 sparse weight tensor.
        """
        sparsity_compressor = self.model_compressor.sparsity_compressor

        def _process_split(
            bitmask_compressed_weight: torch.Tensor,
            shape,
            bitmask: torch.Tensor,
        ) -> torch.Tensor:
            weight_data = dict(
                compressed=bitmask_compressed_weight,
                shape=shape,
                bitmask=bitmask,
            )
            return sparsity_compressor.decompress_weight(weight_data)

        split_weights: List[torch.Tensor] = []
        split_bitmask: List[torch.Tensor] = []
        split_shape: List[Tuple[int, int]] = []

        if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
            split_weights = torch.split(compressed, layer.logical_widths)
            split_bitmask = torch.split(bitmask, layer.logical_widths)
            split_shape = [(out, layer.input_size_per_partition)
                           for out in layer.logical_widths]

        if split_weights:
            decompressed_shards = [
                _process_split(compressed_weight, shape, bitmask)
                for compressed_weight, shape, bitmask in zip(
                    split_weights, split_shape, split_bitmask)
            ]
            decompressed = combine_shards(decompressed_shards)
        else:
            decompressed = sparsity_compressor.decompress_weight(
                dict(
                    compressed=compressed,
                    shape=(
                        layer.logical_widths[0],
                        layer.input_size_per_partition,
                    ),
                    bitmask=bitmask,
                ))
        return decompressed
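
For reference, the sparse-24-bitmask layout stores only the retained values, a packed bitmask marking their positions, and the original shape; the decompression delegated above to `sparsity_compressor.decompress_weight` amounts to scattering the stored values back into a dense tensor. A self-contained sketch of that idea follows (the helper name and the assumed little-endian bit order are illustrative; the canonical behaviour lives in the compressed_tensors library):

import torch

def decompress_2of4_bitmask(values: torch.Tensor, bitmask: torch.Tensor,
                            shape: tuple) -> torch.Tensor:
    """Scatter the stored non-zero values back into a dense (rows, cols) tensor.

    values:  (rows, cols // 2) tensor holding the two kept values per block of 4.
    bitmask: (rows, cols // 8) uint8 tensor, one bit per dense position.
    """
    rows, cols = shape
    # Unpack one bit per dense element; little-endian bit order within each
    # byte is an assumption of this sketch.
    bits = (bitmask.unsqueeze(-1) >> torch.arange(8, dtype=torch.uint8)) & 1
    mask = bits.reshape(rows, -1)[:, :cols].to(torch.bool)
    dense = torch.zeros(rows, cols, dtype=values.dtype)
    dense[mask] = values.reshape(-1)  # True positions are filled row-major
    return dense

# Example: one row of 8 with 2:4 sparsity, non-zeros at positions 0, 2, 4, 6.
vals = torch.tensor([[1., 2., 3., 4.]], dtype=torch.float16)
mask_bytes = torch.tensor([[0b01010101]], dtype=torch.uint8)
print(decompress_2of4_bitmask(vals, mask_bytes, (1, 8)))  # [1, 0, 2, 0, 3, 0, 4, 0]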