Add: Support for Sparse24Bitmask Compressed Models

Rahul Tuli 2025-02-05 15:30:43 -06:00 committed by GitHub
parent af8486de49
commit 3b2005e1db
4 changed files with 503 additions and 112 deletions


@@ -0,0 +1,11 @@
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2
model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.6353
  - name: "exact_match,flexible-extract"
    value: 0.637
limit: null
num_fewshot: null
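
The config above pins the GSM8K reference scores (strict and flexible exact-match) that run-lm-eval-gsm-vllm-baseline.sh checks a fresh evaluation of the Sparse24Bitmask FP8 SparseLlama checkpoint against. Purely as an illustrative sketch, a roughly equivalent invocation through the lm-evaluation-harness Python entry point with the vLLM backend might look as follows; the exact model_args string and batch size are assumptions, not taken from the script:

# Sketch only: evaluate the sparse-24-bitmask FP8 checkpoint on GSM8K
# via lm-evaluation-harness, using vLLM as the inference backend.
import lm_eval

results = lm_eval.simple_evaluate(
    model="vllm",  # assumes the harness's vLLM backend is available
    model_args=("pretrained=nm-testing/SparseLlama-3.1-8B-gsm8k-pruned."
                "2of4-chnl_wts_per_tok_dyn_act_fp8-BitM,"
                "tensor_parallel_size=2"),  # assumed argument string
    tasks=["gsm8k"],
    num_fewshot=None,  # mirrors `num_fewshot: null` above
    limit=None,  # mirrors `limit: null` above
    batch_size="auto",
)
print(results["results"]["gsm8k"])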


@@ -3,6 +3,7 @@
Run `pytest tests/quantization/test_compressed_tensors.py`.
"""
from typing import Optional

import pytest
@@ -22,12 +23,30 @@ from vllm.platforms import current_platform
@pytest.mark.parametrize(
    "model_args",
    [
        (
            "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
            "tensor",
            QuantizationType.INT,
            2560,
            True,
        ),
        (
            "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
            "channel",
            QuantizationType.INT,
            2560,
            True,
        ),
        (
            "nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama",
            "tensor",
            QuantizationType.INT,
            2560,
            False,
        ),
    ],
)
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
    model_path, strategy, quant_type, shape_0, is_symmetric = model_args
    with vllm_runner(model_path, enforce_eager=True) as llm:
@@ -85,21 +104,31 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
        assert output


@pytest.mark.parametrize(
    "model_path",
    [
        "neuralmagic/Llama-3.2-1B-quantized.w8a8",
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym",
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym",
    ],
)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
def test_compressed_tensors_w8a8_logprobs(
    hf_runner,
    vllm_runner,
    example_prompts,
    model_path,
    max_tokens,
    num_logprobs,
):
    dtype = "bfloat16"

    # skip language translation prompt for the static per tensor asym model
    if (model_path ==
            "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
        ):  # noqa: E501
        example_prompts = example_prompts[0:-1]

    with hf_runner(model_path, dtype=dtype) as hf_model:
@@ -125,13 +154,21 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
        assert output


@pytest.mark.parametrize(
    "model_args",
    [
        ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
        ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"),
        (
            "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
            "channel",
        ),
        (
            "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym",
            "channel",
        ),
    ],
)
def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
    model_path, strategy = model_args
    with vllm_runner(model_path, dtype=torch.float16) as llm:
@@ -156,9 +193,12 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
@pytest.mark.parametrize(
    "wNa16_args",
    [
        ("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
        ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
        ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
    ],
)
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
    model, strategy, group, pack_factor = wNa16_args
    with vllm_runner(model) as llm:
@@ -218,7 +258,8 @@ def test_compressed_tensors_fp8(vllm_runner):
                              CompressedTensorsLinearMethod)
            assert isinstance(
                qkv_proj.scheme,
                (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8),
            )

            assert qkv_proj.input_scale.dtype is torch.float32
@@ -241,9 +282,14 @@ def test_compressed_tensors_kv_cache(vllm_runner):
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
def _test_2of4_quant_models(qkv_proj,
                            weight_strategy,
                            input_strategy,
                            format="dense"):
    assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
    assert isinstance(qkv_proj.scheme, CompressedTensors24)
@@ -252,22 +298,39 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
    assert qkv_proj.scheme.quantized
    assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
    sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
    assert sparsity_map.get("Linear").format == format
    assert sparsity_map.get("Linear").sparsity_structure == "2:4"


@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing",
            "channel",
            "token",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing",
            "tensor",
            "tensor",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
            "tensor",
            "token",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:
@@ -286,16 +349,134 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
        assert output


@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM",
            "tensor",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM",
            "tensor",
            "tensor",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_fp8_compressed(vllm_runner, args_2of4):
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
            _test_2of4_quant_models(
                qkv_proj,
                weight_strategy,
                input_strategy,
                format="sparse-24-bitmask",
            )

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM",
            "tensor",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM",
            "tensor",
            "tensor",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_int8_compressed(vllm_runner, args_2of4):
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.int8
            _test_2of4_quant_models(
                qkv_proj,
                weight_strategy,
                input_strategy,
                format="sparse-24-bitmask",
            )

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing",
            "tensor",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
            "tensor",
            "token",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:
@@ -317,10 +498,12 @@ def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="2of4 Sparse is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")],
)
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
    model = args_2of4
    with vllm_runner(model) as llm:
@@ -337,7 +520,9 @@ def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
            assert qkv_proj.scheme.input_quant is None
            assert not qkv_proj.scheme.quantized
            assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
            sparsity_map = (
                qkv_proj.quant_method.quantization_config.sparsity_scheme_map
            )  # noqa: E501
            assert sparsity_map.get("Linear").format == "dense"
            assert sparsity_map.get("Linear").sparsity_structure == "2:4"
@@ -346,3 +531,38 @@ def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4", [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")])
def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4):
    model = args_2of4
    with vllm_runner(model) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensors24)

            assert qkv_proj.scheme.weight_quant is None
            assert qkv_proj.scheme.input_quant is None
            assert not qkv_proj.scheme.quantized
            assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
            sparsity_map = (
                qkv_proj.quant_method.quantization_config.sparsity_scheme_map
            )  # noqa: E501
            assert sparsity_map.get("Linear").format == "sparse-24-bitmask"
            assert sparsity_map.get("Linear").sparsity_structure == "2:4"

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output
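
Beyond the unit tests, a Sparse24Bitmask checkpoint loads through the normal vLLM entry points with no extra flags; the compressed-tensors config shipped with the model selects the sparse-24-bitmask path, just as the tests above rely on. A minimal offline-inference sketch (model name taken from the test above; sampling settings chosen to mimic the greedy generation in the tests):

# Sketch: greedy generation on a sparse-24-bitmask compressed checkpoint.
# Requires a GPU that satisfies the sparse CUTLASS / capability checks above.
from vllm import LLM, SamplingParams

llm = LLM(model="nm-testing/llama2.c-stories42M-pruned2.4-compressed")
params = SamplingParams(temperature=0.0, max_tokens=20)  # greedy, like the tests
outputs = llm.generate(["Hello my name is"], params)
print(outputs[0].outputs[0].text)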


@@ -417,15 +417,22 @@ class CompressedTensorsConfig(QuantizationConfig):
                return None

            # Have a valid sparsity scheme
            # Validate layer is supported by Cutlass 2:4 Kernel
            model_compression_config = (None if sparsity_scheme is None
                                        or sparsity_scheme.format == "dense"
                                        else self.config)

            scheme = CompressedTensors24(
                quantized=weight_quant is not None or input_quant is not None,
                weight_quant=weight_quant,
                input_quant=input_quant,
                model_compression_config=model_compression_config,
            )
        elif weight_quant is None:
            logger.warning_once("Acceleration for non-quantized schemes is "
                                "not supported by Compressed Tensors. "
                                "Falling back to UnquantizedLinearMethod")
            return None
        else:
            # Find the quant_scheme
            scheme = self._get_scheme_from_parts(  # type: ignore
@@ -475,10 +482,21 @@ class CompressedTensorsConfig(QuantizationConfig):
        :return: True if the layer is supported by the Cutlass 2:4 Kernel
                 False otherwise
        """
        if sparsity_scheme is None:
            return False

        is_valid_sparsity_structure: bool = (
            sparsity_scheme.sparsity_structure ==
            SparsityStructure.TWO_FOUR.value)

        valid_compressors = {
            CompressionFormat.dense.value,
            CompressionFormat.sparse_24_bitmask.value
        }

        is_valid_sparsity = (is_valid_sparsity_structure
                             and sparsity_scheme.format in valid_compressors)

        if not is_valid_sparsity:
            return False


@@ -1,13 +1,17 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from compressed_tensors import CompressionFormat, ModelCompressor
from compressed_tensors.quantization import (QuantizationArgs,
                                             QuantizationStrategy,
                                             QuantizationType)
from compressed_tensors.utils import combine_shards

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -22,26 +26,39 @@ __all__ = ["CompressedTensors24"]
class CompressedTensors24(CompressedTensorsScheme):

    def __init__(
        self,
        quantized: bool = False,
        weight_quant: Optional[QuantizationArgs] = None,
        input_quant: Optional[QuantizationArgs] = None,
        model_compression_config: Optional[Dict[str, Any]] = None,
    ):
        self.quantized = quantized
        self.weight_quant = weight_quant
        self.input_quant = input_quant
        self.model_compressor = (
            ModelCompressor.from_compression_config(model_compression_config)
            if model_compression_config is not None else None)
        self.do_sparse_decompress = (
            self.model_compressor is not None
            and self.model_compressor.sparsity_config.format
            == CompressionFormat.sparse_24_bitmask.value)

    @classmethod
    def get_min_capability(cls) -> int:
        # Only cutlass 3.x kernels are implemented so far
        return 90

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size: int,
        output_partition_sizes: List[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        if not sparse_cutlass_supported():
            raise ValueError(
                "Sparse CUTLASS not supported. vLLM must be built with "
@@ -49,16 +66,56 @@ class CompressedTensors24(CompressedTensorsScheme):
        self.output_dtype = params_dtype
        layer.logical_widths = output_partition_sizes
        layer.input_size = input_size
        layer.input_size_per_partition = input_size_per_partition
        self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)

        # parameter to store uncompressed weight
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=self.weights_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        if self.do_sparse_decompress:
            assert all(
                partition_size % 8 == 0
                for partition_size in output_partition_sizes
            ), ("All partitions must be divisible by 8 for "
                "2:4 sparse compressed models")

            shape = BasevLLMParameter(
                data=torch.empty(2, 1, dtype=torch.int64),
                weight_loader=weight_loader,
            )
            compressed_weight = ModelWeightParameter(
                data=torch.empty(
                    sum(output_partition_sizes),
                    input_size_per_partition // 2,
                    dtype=self.weights_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )

            bitmask = ModelWeightParameter(
                data=torch.empty(
                    sum(output_partition_sizes),
                    input_size_per_partition // 8,
                    dtype=torch.uint8,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )

            layer.register_parameter("shape", shape)
            layer.register_parameter("compressed", compressed_weight)
            layer.register_parameter("bitmask", bitmask)

        # Check if quantized, not just 2:4 Sparse
        if self.quantized:
@@ -68,14 +125,16 @@ class CompressedTensors24(CompressedTensorsScheme):
                    data=torch.empty((sum(output_partition_sizes), 1),
                                     dtype=torch.float32),
                    output_dim=0,
                    weight_loader=weight_loader,
                )
            else:
                assert (self.weight_quant and self.weight_quant.strategy
                        == QuantizationStrategy.TENSOR.value)
                weight_scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )

            layer.register_parameter("weight_scale", weight_scale)
@@ -84,9 +143,10 @@ class CompressedTensors24(CompressedTensorsScheme):
                # register input quant scale
                assert (self.input_quant.strategy ==
                        QuantizationStrategy.TENSOR.value)
                input_scale = BasevLLMParameter(
                    data=torch.empty(1, dtype=torch.float32),
                    weight_loader=weight_loader,
                )

                layer.register_parameter("input_scale", input_scale)
@@ -114,6 +174,18 @@ class CompressedTensors24(CompressedTensorsScheme):
        :param layer: The layer with the weights to be processed
        """
        if self.do_sparse_decompress:
            layer.weight.data = self._decompress_bitmask_compressed_weight(
                compressed=layer.compressed,
                bitmask=layer.bitmask,
                layer=layer,
            )

            # compressed and bitmask tensors
            # are no longer needed after decompression
            del layer.compressed
            del layer.bitmask

        # torch.compile workaround
        if hasattr(layer, "input_scale"):
            layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
@@ -121,10 +193,13 @@ class CompressedTensors24(CompressedTensorsScheme):
        if self.weight_quant:
            if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
                layer.weight_scale = torch.nn.Parameter(
                    convert_to_channelwise(
                        weight_scale=layer.weight_scale,
                        logical_widths=layer.logical_widths,
                    ),
                    requires_grad=False,
                )
            else:
                # torch.compile workaround
                layer.weight_scale = torch.nn.Parameter(
@@ -134,10 +209,12 @@ class CompressedTensors24(CompressedTensorsScheme):
        layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
        layer.meta = torch.nn.Parameter(meta, requires_grad=False)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Returns the output tensor for the layer with 2:4
        sparse compressed weights, given the input tensor
@@ -171,13 +248,15 @@ class CompressedTensors24(CompressedTensorsScheme):
            input_scale = layer.input_scale
            q_input = x

        out = ops.cutlass_scaled_sparse_mm(
            a=q_input,
            bt_nzs=layer.weight,
            bt_meta=layer.meta,
            scale_a=input_scale,
            scale_b=layer.weight_scale,
            out_dtype=self.output_dtype,
            bias=bias,
        )
        assert out.is_contiguous()
        return out
@@ -203,8 +282,71 @@ class CompressedTensors24(CompressedTensorsScheme):
            raise ValueError("Quantization type not supported by Cutlass")

    def _decompress_bitmask_compressed_weight(
        self,
        compressed: torch.Tensor,
        bitmask: torch.Tensor,
        layer: torch.nn.Module,
    ) -> torch.Tensor:
        """
        Decompress a compressed 2:4 sparse weight tensor using the bitmask and
        return the result.

        This function also supports sharded decompression.

        :param compressed: The 2:4 sparse weight tensor compressed using the
            sparse-24-bitmask compressor. This is different from
            `cutlass_sparse_compress`, which uses a different scheme (2 bits
            for every nonzero element that represent the coordinate within the
            block of 4). The bitmask compression here uses a bitmask to
            indicate the positions of non-zero elements.
        :param bitmask: The 2:4 bitmask associated with the compressed weights,
            representing the positions of non-zero elements in the compressed
            tensor.
        :param layer: The layer whose weights need to be processed after
            loading.
        :return: The decompressed 2:4 sparse weight tensor.
        """
        sparsity_compressor = self.model_compressor.sparsity_compressor

        def _process_split(
            bitmask_compressed_weight: torch.Tensor,
            shape,
            bitmask: torch.Tensor,
        ) -> torch.Tensor:
            weight_data = dict(
                compressed=bitmask_compressed_weight,
                shape=shape,
                bitmask=bitmask,
            )
            return sparsity_compressor.decompress_weight(weight_data)

        split_weights: List[torch.Tensor] = []
        split_bitmask: List[torch.Tensor] = []
        split_shape: List[Tuple[int, int]] = []

        if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
            split_weights = torch.split(compressed, layer.logical_widths)
            split_bitmask = torch.split(bitmask, layer.logical_widths)
            split_shape = [(out, layer.input_size_per_partition)
                           for out in layer.logical_widths]

        if split_weights:
            decompressed_shards = [
                _process_split(compressed_weight, shape, bitmask)
                for compressed_weight, shape, bitmask in zip(
                    split_weights, split_shape, split_bitmask)
            ]
            decompressed = combine_shards(decompressed_shards)
        else:
            decompressed = sparsity_compressor.decompress_weight(
                dict(
                    compressed=compressed,
                    shape=(
                        layer.logical_widths[0],
                        layer.input_size_per_partition,
                    ),
                    bitmask=bitmask,
                ))
        return decompressed
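
The docstring above distinguishes the bitmask layout from `cutlass_sparse_compress`: rather than storing 2-bit coordinates per nonzero, the sparse-24-bitmask compressor keeps only the nonzero values plus a bitmask marking their positions, which is why create_weights sizes `compressed` at input_size_per_partition // 2 values and `bitmask` at input_size_per_partition // 8 uint8 bytes per row. The real decompression is delegated to compressed-tensors' sparsity_compressor; purely as a conceptual sketch (not the library's implementation, and assuming LSB-first bit packing and row-major value order), a toy dense reconstruction could look like this:

import torch


def decompress_24_bitmask_toy(compressed: torch.Tensor,
                              bitmask: torch.Tensor) -> torch.Tensor:
    """Toy rebuild of a dense (N, K) weight from 2:4-sparse storage.

    compressed: (N, K // 2) nonzero values in row-major order (assumed).
    bitmask:    (N, K // 8) uint8, one bit per dense position (LSB-first
                packing assumed; the real compressor's layout may differ).
    """
    n, half_k = compressed.shape
    k = half_k * 2
    # Unpack each byte into 8 bits -> boolean mask of shape (N, K).
    shifts = torch.arange(8, device=bitmask.device, dtype=torch.uint8)
    bits = (bitmask.unsqueeze(-1) >> shifts) & 1
    mask = bits.reshape(n, -1)[:, :k].bool()
    # Scatter the stored nonzeros back into their dense positions; exact 2:4
    # sparsity guarantees each row contributes exactly K // 2 values.
    dense = torch.zeros(n, k, dtype=compressed.dtype, device=compressed.device)
    dense[mask] = compressed.reshape(-1)
    return dense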