Add: Support for Sparse24Bitmask Compressed Models
commit 3b2005e1db (parent af8486de49)
@@ -0,0 +1,11 @@
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2
model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.6353
  - name: "exact_match,flexible-extract"
    value: 0.637
limit: null
num_fewshot: null
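For context, configs like this are consumed by an accuracy-check harness that runs lm-eval against a vLLM baseline and compares the measured gsm8k metrics to the recorded values. A minimal sketch of such a comparison, assuming a `measured` dict returned by lm-eval; the function name and tolerance are illustrative, not the repo's actual harness:

```python
import yaml

TOLERANCE = 0.05  # illustrative, not necessarily what the CI harness uses

def check_accuracy(config_path: str, measured: dict) -> None:
    """Compare measured gsm8k metrics against the values recorded in the config."""
    with open(config_path) as f:
        expected = yaml.safe_load(f)
    for task in expected["tasks"]:
        for metric in task["metrics"]:
            name, want = metric["name"], metric["value"]
            got = measured[task["name"]][name]
            assert abs(got - want) <= TOLERANCE, f"{name}: got {got}, expected {want}"
```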
@@ -3,6 +3,7 @@

Run `pytest tests/quantization/test_compressed_tensors.py`.
"""

from typing import Optional

import pytest
@@ -22,12 +23,30 @@ from vllm.platforms import current_platform

@pytest.mark.parametrize(
    "model_args",
    [("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
      QuantizationType.INT, 2560, True),
     ("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
      QuantizationType.INT, 2560, True),
     ("nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama", "tensor",
      QuantizationType.INT, 2560, False)])
    [
        (
            "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
            "tensor",
            QuantizationType.INT,
            2560,
            True,
        ),
        (
            "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
            "channel",
            QuantizationType.INT,
            2560,
            True,
        ),
        (
            "nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama",
            "tensor",
            QuantizationType.INT,
            2560,
            False,
        ),
    ],
)
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
    model_path, strategy, quant_type, shape_0, is_symmetric = model_args
    with vllm_runner(model_path, enforce_eager=True) as llm:
@@ -85,21 +104,31 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
        assert output


@pytest.mark.parametrize("model_path", [
@pytest.mark.parametrize(
    "model_path",
    [
        "neuralmagic/Llama-3.2-1B-quantized.w8a8",
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym",
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
    "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
])
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym",
    ],
)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
def test_compressed_tensors_w8a8_logprobs(hf_runner, vllm_runner,
                                          example_prompts, model_path,
                                          max_tokens, num_logprobs):
def test_compressed_tensors_w8a8_logprobs(
    hf_runner,
    vllm_runner,
    example_prompts,
    model_path,
    max_tokens,
    num_logprobs,
):
    dtype = "bfloat16"

    # skip language translation prompt for the static per tensor asym model
    if model_path == "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym":  # noqa: E501
    if (model_path ==
            "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
       ):  # noqa: E501
        example_prompts = example_prompts[0:-1]

    with hf_runner(model_path, dtype=dtype) as hf_model:
@@ -125,13 +154,21 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
        assert output


@pytest.mark.parametrize("model_args", [
@pytest.mark.parametrize(
    "model_args",
    [
        ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
        ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"),
    ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"),
    ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym",
     "channel"),
])
        (
            "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
            "channel",
        ),
        (
            "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym",
            "channel",
        ),
    ],
)
def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
    model_path, strategy = model_args
    with vllm_runner(model_path, dtype=torch.float16) as llm:
@@ -156,9 +193,12 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):

@pytest.mark.parametrize(
    "wNa16_args",
    [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
    [
        ("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
        ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
     ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4)])
        ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
    ],
)
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
    model, strategy, group, pack_factor = wNa16_args
    with vllm_runner(model) as llm:
@@ -218,7 +258,8 @@ def test_compressed_tensors_fp8(vllm_runner):
                          CompressedTensorsLinearMethod)
        assert isinstance(
            qkv_proj.scheme,
            (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8))
            (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8),
        )

        assert qkv_proj.input_scale.dtype is torch.float32

@@ -241,9 +282,14 @@ def test_compressed_tensors_kv_cache(vllm_runner):
        assert output


@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse FP8 is not yet supported on this GPU type.")
def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
def _test_2of4_quant_models(qkv_proj,
                            weight_strategy,
                            input_strategy,
                            format="dense"):
    assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
    assert isinstance(qkv_proj.scheme, CompressedTensors24)

@@ -252,22 +298,39 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
    assert qkv_proj.scheme.quantized
    assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
    sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
    assert sparsity_map.get("Linear").format == "dense"
    assert sparsity_map.get("Linear").format == format
    assert sparsity_map.get("Linear").sparsity_structure == "2:4"


@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", "channel",
     "token"),
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
     "channel", "tensor"),
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", "tensor",
     "tensor"),
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
     "tensor", "token"),
])
@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing",
            "channel",
            "token",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing",
            "tensor",
            "tensor",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
            "tensor",
            "token",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:
@@ -286,16 +349,134 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
    assert output


@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
     "channel", "token"),
    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing", "tensor",
     "tensor"),
    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
     "tensor", "token"),
])
@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM",
            "tensor",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM",
            "tensor",
            "tensor",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_fp8_compressed(vllm_runner, args_2of4):
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
            _test_2of4_quant_models(
                qkv_proj,
                weight_strategy,
                input_strategy,
                format="sparse-24-bitmask",
            )

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM",
            "tensor",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM",
            "tensor",
            "tensor",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_int8_compressed(vllm_runner, args_2of4):
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.int8
            _test_2of4_quant_models(
                qkv_proj,
                weight_strategy,
                input_strategy,
                format="sparse-24-bitmask",
            )

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing",
            "tensor",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
            "tensor",
            "token",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:
@@ -317,10 +498,12 @@ def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="2of4 Sparse is not yet supported on this GPU type.")
    reason="2of4 Sparse is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")])
    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")],
)
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
    model = args_2of4
    with vllm_runner(model) as llm:
@@ -337,7 +520,9 @@ def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
            assert qkv_proj.scheme.input_quant is None
            assert not qkv_proj.scheme.quantized
            assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
            sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
            sparsity_map = (
                qkv_proj.quant_method.quantization_config.sparsity_scheme_map
            )  # noqa: E501
            assert sparsity_map.get("Linear").format == "dense"
            assert sparsity_map.get("Linear").sparsity_structure == "2:4"

@@ -346,3 +531,38 @@ def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4", [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")])
def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4):
    model = args_2of4
    with vllm_runner(model) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensors24)

            assert qkv_proj.scheme.weight_quant is None
            assert qkv_proj.scheme.input_quant is None
            assert not qkv_proj.scheme.quantized
            assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
            sparsity_map = (
                qkv_proj.quant_method.quantization_config.sparsity_scheme_map
            )  # noqa: E501
            assert sparsity_map.get("Linear").format == "sparse-24-bitmask"
            assert sparsity_map.get("Linear").sparsity_structure == "2:4"

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output
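The tests above exercise the new sparse-24-bitmask checkpoints through the test fixtures. Outside the test suite, a minimal sketch of loading one of these checkpoints with the public vLLM API might look like the following; the model name is taken from the parametrization above, and the sampling settings are illustrative. As the skipif markers indicate, this path assumes a GPU with sparse CUTLASS support (compute capability 9.0).

```python
from vllm import LLM, SamplingParams

# Any of the sparse-24-bitmask checkpoints parametrized above should work here.
llm = LLM(model="nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM")

# Greedy generation, mirroring generate_greedy in the tests.
outputs = llm.generate(["Hello my name is"],
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)
```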
@@ -417,15 +417,22 @@ class CompressedTensorsConfig(QuantizationConfig):
                return None

            # Have a valid sparsity scheme
            # Validate layer is supported by Cutlass 2:4 Kernel
            scheme = CompressedTensors24(quantized=weight_quant is not None
                                         or input_quant is not None,
            model_compression_config = (None if sparsity_scheme is None
                                        or sparsity_scheme.format == "dense"
                                        else self.config)

            scheme = CompressedTensors24(
                quantized=weight_quant is not None or input_quant is not None,
                weight_quant=weight_quant,
                input_quant=input_quant)
                input_quant=input_quant,
                model_compression_config=model_compression_config,
            )
        elif weight_quant is None:
            logger.warning_once("Acceleration for non-quantized schemes is "
                                "not supported by Compressed Tensors. "
                                "Falling back to UnquantizedLinearMethod")
            return None

        else:
            # Find the quant_scheme
            scheme = self._get_scheme_from_parts(  # type: ignore
@@ -475,10 +482,21 @@ class CompressedTensorsConfig(QuantizationConfig):
        :return: True if the layer is supported by the Cutlass 2:4 Kernel
            False otherwise
        """
        is_valid_sparsity = (sparsity_scheme is not None
                             and sparsity_scheme.sparsity_structure
                             == SparsityStructure.TWO_FOUR.value
                             and sparsity_scheme.format == "dense")
        if sparsity_scheme is None:
            return False

        is_valid_sparsity_structure: bool = (
            sparsity_scheme.sparsity_structure ==
            SparsityStructure.TWO_FOUR.value)

        valid_compressors = {
            CompressionFormat.dense.value,
            CompressionFormat.sparse_24_bitmask.value
        }

        is_valid_sparsity = (is_valid_sparsity_structure
                             and sparsity_scheme.format in valid_compressors)

        if not is_valid_sparsity:
            return False

@@ -1,13 +1,17 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from compressed_tensors import CompressionFormat, ModelCompressor
from compressed_tensors.quantization import (QuantizationArgs,
                                             QuantizationStrategy,
                                             QuantizationType)
from compressed_tensors.utils import combine_shards

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -22,26 +26,39 @@ __all__ = ["CompressedTensors24"]

class CompressedTensors24(CompressedTensorsScheme):

    def __init__(self,
    def __init__(
        self,
        quantized: bool = False,
        weight_quant: Optional[QuantizationArgs] = None,
        input_quant: Optional[QuantizationArgs] = None):

        input_quant: Optional[QuantizationArgs] = None,
        model_compression_config: Optional[Dict[str, Any]] = None,
    ):
        self.quantized = quantized
        self.weight_quant = weight_quant
        self.input_quant = input_quant
        self.model_compressor = (
            ModelCompressor.from_compression_config(model_compression_config)
            if model_compression_config is not None else None)
        self.do_sparse_decompress = (
            self.model_compressor is not None
            and self.model_compressor.sparsity_config.format
            == CompressionFormat.sparse_24_bitmask.value)

    @classmethod
    def get_min_capability(cls) -> int:
        # Only cutlass 3.x kernels are implemented so far
        return 90

    def create_weights(self, layer: torch.nn.Module, input_size: int,
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size: int,
        output_partition_sizes: List[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype, weight_loader: Callable,
        **kwargs):

        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        if not sparse_cutlass_supported():
            raise ValueError(
                "Sparse CUTLASS not supported. vLLM must be built with "
@@ -49,16 +66,56 @@ class CompressedTensors24(CompressedTensorsScheme):

        self.output_dtype = params_dtype
        layer.logical_widths = output_partition_sizes
        layer.input_size = input_size
        layer.input_size_per_partition = input_size_per_partition
        self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)

        # parameter to store uncompressed weight
        weight = ModelWeightParameter(data=torch.empty(
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=self.weights_dtype),
                dtype=self.weights_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader)
            weight_loader=weight_loader,
        )
        if self.do_sparse_decompress:
            assert all(partition_size % 8 == 0
                       for partition_size in output_partition_sizes
                       ), "All partitions must be divisible by 8 for "
            "2:4 sparse compressed models"

            shape = BasevLLMParameter(
                data=torch.empty(2, 1, dtype=torch.int64),
                weight_loader=weight_loader,
            )
            compressed_weight = ModelWeightParameter(
                data=torch.empty(
                    sum(output_partition_sizes),
                    input_size_per_partition // 2,
                    dtype=self.weights_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )

            bitmask = ModelWeightParameter(
                data=torch.empty(
                    sum(output_partition_sizes),
                    input_size_per_partition // 8,
                    dtype=torch.uint8,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )

            layer.register_parameter("shape", shape)
            layer.register_parameter("compressed", compressed_weight)
            layer.register_parameter("bitmask", bitmask)

        # Check if quantized, not just 2:4 Sparse
        if self.quantized:
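As a quick sanity check on the shapes registered above: for a partition with N output rows and K input columns, 2:4 sparsity keeps half of the values, so `compressed` holds N x K/2 values in the weight dtype, while `bitmask` packs one bit per original element into N x K/8 uint8 bytes. A small illustrative calculation (the layer dimensions are hypothetical, not from the commit):

```python
# Hypothetical layer size, purely illustrative.
n_out, k_in = 4096, 11008

compressed_elems = n_out * (k_in // 2)  # kept values after 2:4 pruning
bitmask_bytes = n_out * (k_in // 8)     # 1 bit per original element, packed into uint8
dense_elems = n_out * k_in

print(compressed_elems, bitmask_bytes, dense_elems)
# 22544384 5636096 45088768 -> ~0.5x the values plus ~0.125x bytes of mask vs dense
```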
@@ -68,14 +125,16 @@ class CompressedTensors24(CompressedTensorsScheme):
                    data=torch.empty((sum(output_partition_sizes), 1),
                                     dtype=torch.float32),
                    output_dim=0,
                    weight_loader=weight_loader)
                    weight_loader=weight_loader,
                )
            else:
                assert (self.weight_quant and self.weight_quant.strategy
                        == QuantizationStrategy.TENSOR.value)
                weight_scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader)
                    weight_loader=weight_loader,
                )

            layer.register_parameter("weight_scale", weight_scale)

@@ -84,9 +143,10 @@ class CompressedTensors24(CompressedTensorsScheme):
                # register input quant scale
                assert (self.input_quant.strategy ==
                        QuantizationStrategy.TENSOR.value)
                input_scale = BasevLLMParameter(data=torch.empty(
                    1, dtype=torch.float32),
                                                weight_loader=weight_loader)
                input_scale = BasevLLMParameter(
                    data=torch.empty(1, dtype=torch.float32),
                    weight_loader=weight_loader,
                )

                layer.register_parameter("input_scale", input_scale)

@@ -114,6 +174,18 @@ class CompressedTensors24(CompressedTensorsScheme):
        :param layer: The layer with the weights to be processed

        """
        if self.do_sparse_decompress:
            layer.weight.data = self._decompress_bitmask_compressed_weight(
                compressed=layer.compressed,
                bitmask=layer.bitmask,
                layer=layer,
            )

            # compressed and bitmask tensors
            # are no longer needed after decompression
            del layer.compressed
            del layer.bitmask

        # torch.compile workaround
        if hasattr(layer, "input_scale"):
            layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
@@ -121,10 +193,13 @@ class CompressedTensors24(CompressedTensorsScheme):

        if self.weight_quant:
            if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
                layer.weight_scale = torch.nn.Parameter(convert_to_channelwise(
                layer.weight_scale = torch.nn.Parameter(
                    convert_to_channelwise(
                        weight_scale=layer.weight_scale,
                        logical_widths=layer.logical_widths),
                    requires_grad=False)
                        logical_widths=layer.logical_widths,
                    ),
                    requires_grad=False,
                )
            else:
                # torch.compile workaround
                layer.weight_scale = torch.nn.Parameter(
@@ -134,10 +209,12 @@ class CompressedTensors24(CompressedTensorsScheme):
        layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
        layer.meta = torch.nn.Parameter(meta, requires_grad=False)

    def apply_weights(self,
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Returns the output tensor for the layer with 2:4
        sparse compressed weights, given the input tensor
@@ -171,13 +248,15 @@ class CompressedTensors24(CompressedTensorsScheme):
            input_scale = layer.input_scale
            q_input = x

        out = ops.cutlass_scaled_sparse_mm(a=q_input,
        out = ops.cutlass_scaled_sparse_mm(
            a=q_input,
            bt_nzs=layer.weight,
            bt_meta=layer.meta,
            scale_a=input_scale,
            scale_b=layer.weight_scale,
            out_dtype=self.output_dtype,
            bias=bias)
            bias=bias,
        )
        assert out.is_contiguous()
        return out

@@ -203,8 +282,71 @@ class CompressedTensors24(CompressedTensorsScheme):

        raise ValueError("Quantization type not supported by Cutlass")

    def _decompress_bitmask_compressed_weight(
        self,
        compressed: torch.Tensor,
        bitmask: torch.Tensor,
        layer: torch.nn.Module,
    ) -> torch.Tensor:
        """
        Decompress a compressed 2:4 sparse weight tensor using the bitmask and
        return the result.

        def check_24(tensor):
            new_tensor = tensor.view(-1, 4)
            zero_counts = (new_tensor == 0).sum(dim=1)
            return (zero_counts >= 2).all().item()
        This function also supports sharded decompression.

        :param compressed: The 2:4 sparse weight tensor compressed using the
            sparse-24-bitmask compressor. This is different from
            `cutlass_sparse_compress` which uses a different scheme (2 bits for
            every nonzero element that represent the coordinate within the block
            of 4). The bitmask compression here uses a bitmask to indicate the
            positions of non-zero elements.
        :param bitmask: The 2:4 bitmask associated with the compressed weights,
            representing the positions of non-zero elements in the compressed
            tensor.
        :param layer: The layer whose weights need to be processed after
            loading.
        :return: The decompressed 2:4 sparse weight tensor.
        """

        sparsity_compressor = self.model_compressor.sparsity_compressor

        def _process_split(
            bitmask_compressed_weight: torch.Tensor,
            shape,
            bitmask: torch.Tensor,
        ) -> torch.Tensor:
            weight_data = dict(
                compressed=bitmask_compressed_weight,
                shape=shape,
                bitmask=bitmask,
            )
            return sparsity_compressor.decompress_weight(weight_data)

        split_weights: List[torch.Tensor] = []
        split_bitmask: List[torch.Tensor] = []
        split_shape: List[Tuple[int, int]] = []

        if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
            split_weights = torch.split(compressed, layer.logical_widths)
            split_bitmask = torch.split(bitmask, layer.logical_widths)
            split_shape = [(out, layer.input_size_per_partition)
                           for out in layer.logical_widths]

        if split_weights:
            decompressed_shards = [
                _process_split(compressed_weight, shape, bitmask)
                for compressed_weight, shape, bitmask in zip(
                    split_weights, split_shape, split_bitmask)
            ]
            decompressed = combine_shards(decompressed_shards)
        else:
            decompressed = sparsity_compressor.decompress_weight(
                dict(
                    compressed=compressed,
                    shape=(
                        layer.logical_widths[0],
                        layer.input_size_per_partition,
                    ),
                    bitmask=bitmask,
                ))
        return decompressed

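The docstring above distinguishes the bitmask layout (one bit per original element marking which positions are non-zero) from the CUTLASS 2-bit-metadata layout. A minimal, self-contained sketch of what bitmask-based decompression does conceptually, using an unpacked boolean mask instead of packed uint8 bits and hypothetical shapes; the real work in the commit is delegated to compressed-tensors' `decompress_weight`:

```python
import torch

def decompress_24_bitmask_sketch(values: torch.Tensor,
                                 mask: torch.Tensor) -> torch.Tensor:
    """Scatter the kept 2:4 values back into a dense tensor.

    values: (N, K // 2) tensor holding the non-zero entries row by row.
    mask:   (N, K) boolean tensor with exactly 2 True entries per group of 4.
    """
    dense = torch.zeros(mask.shape, dtype=values.dtype)
    dense[mask] = values.reshape(-1)  # boolean indexing assigns in row-major order
    return dense

# Tiny example: one row, K = 8, kept values 1..4 at positions 0, 2, 5, 6.
mask = torch.tensor([[True, False, True, False, False, True, True, False]])
values = torch.tensor([[1., 2., 3., 4.]])
print(decompress_24_bitmask_sketch(values, mask))
# tensor([[1., 0., 2., 0., 0., 3., 4., 0.]])
```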