Fix: Respect sparsity_config.ignore in Cutlass Integration (#12517)
This PR fixes a bug in the Cutlass integration where the `sparsity_config.ignore` list was not respected: when only a subset of modules was configured as Sparse24, Cutlass was incorrectly selected for the non-sparse modules as well. With this change, non-sparse modules are routed to the appropriate scheme.

---

### Changes

- Updated the scheme-selection logic to respect `sparsity_config.ignore`.
- Ensured non-sparse modules use the appropriate scheme instead of defaulting to Cutlass.

---

<details>
<summary>Testing Setup</summary>

The fix has been tested on top of [this diff](https://github.com/vllm-project/vllm/pull/12097).

#### Steps to Test:

```bash
git checkout -b my-test-branch origin/rahul-bitmask-additions # compressed Cutlass support
git revert --no-edit aa2cd2c # revert Tyler's commit to turn off Cutlass for W16A16
git cherry-pick ca624cddb # this branch
```

#### Additional Patch Required:

```diff
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index a54177c1c..f916dd0c9 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -9,7 +9,7 @@ from compressed_tensors.quantization import (QuantizationArgs,
                                              QuantizationStrategy,
                                              QuantizationType)
 from pydantic import BaseModel
-
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
@@ -27,7 +27,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     should_ignore_layer)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import current_platform
-
+logger = init_logger(__name__)
 __all__ = ["CompressedTensorsLinearMethod"]
 
 SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
```

Apply using:

```bash
git apply logging-patch.patch
```

</details>

---

<details>
<summary>Models Tested</summary>

- `nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-partial-24`
- `nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-full-sparse24`
- `nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-partial-24-entire-fp8-compressed`
- `nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-partial-24-remaining-fp8-compressed`

</details>

---

<details>
<summary>Example Output</summary>

#### Layers 0-5 (Sparse24)

```
Using scheme: CompressedTensors24 for model.layers.0.self_attn.qkv_proj
Using scheme: CompressedTensors24 for model.layers.0.self_attn.o_proj
Using scheme: CompressedTensors24 for model.layers.0.mlp.gate_up_proj
Using scheme: CompressedTensors24 for model.layers.0.mlp.down_proj
...
```

#### Layers 6+ (Non-Sparse, FP8)

```
Using scheme: CompressedTensorsW8A8Fp8 for model.layers.6.self_attn.qkv_proj
Using scheme: CompressedTensorsW8A8Fp8 for model.layers.6.self_attn.o_proj
Using scheme: CompressedTensorsW8A8Fp8 for model.layers.6.mlp.gate_up_proj
Using scheme: CompressedTensorsW8A8Fp8 for model.layers.6.mlp.down_proj
...
```

</details>

**Note:** This assumes that all modules in fused layers such as `qkv_proj` and `gate_up_proj` follow the same quantization/pruning scheme.

---

For related tasks using the Asana app for GitHub, refer to [this link](https://app.asana.com/0/0/1209227810815160).

Signed-off-by: Rahul Tuli <rahul@neuralmagic.com>
Parent: cfa134d247 · Commit: 3e1c76cf3a
`vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py`:

```diff
@@ -1,4 +1,5 @@
-from typing import Any, Dict, List, Literal, Optional, cast
+from contextlib import suppress
+from typing import Any, Dict, List, Literal, Optional, Tuple, cast
 
 import torch
 from compressed_tensors.config import (CompressionFormat,
@@ -44,6 +45,7 @@ class CompressedTensorsConfig(QuantizationConfig):
                  ignore: List[str],
                  quant_format: str,
                  sparsity_scheme_map: Dict[str, SparsityCompressionConfig],
+                 sparsity_ignore_list: List[str],
                  kv_cache_scheme: Optional[Dict[str, Any]] = None,
                  config: Optional[Dict[str, Any]] = None,
                  ):
@@ -54,6 +56,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         self.target_scheme_map = target_scheme_map
         self.kv_cache_scheme = kv_cache_scheme
         self.sparsity_scheme_map = sparsity_scheme_map
+        self.sparsity_ignore_list = sparsity_ignore_list
         self.config = config
 
     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
@@ -98,7 +101,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         quant_format = cast(str, config.get("format"))
         target_scheme_map = cls._quantization_scheme_map_from_config(
             config=config)
-        sparsity_scheme_map = cls._sparsity_scheme_map_from_config(
+        sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
             config=config)
 
         return cls(
@@ -106,20 +109,23 @@ class CompressedTensorsConfig(QuantizationConfig):
             ignore=ignore,
             quant_format=quant_format,
             sparsity_scheme_map=sparsity_scheme_map,
+            sparsity_ignore_list=sparsity_ignore_list,
             config=config,
         )
 
     @classmethod
-    def _sparsity_scheme_map_from_config(
-            cls, config: Dict[str,
-                              Any]) -> Dict[str, SparsityCompressionConfig]:
+    def _parse_sparsity_config(
+        cls, config: Dict[str, Any]
+    ) -> Tuple[Dict[str, SparsityCompressionConfig], List[str]]:
         """
         :param config: The `quantization_config` dictionary from config.json
-        :return: A dictionary mapping target layer names to their corresponding
-            sparsity compression configurations
+        :return: A tuple with two elements
+            1. A dictionary mapping target layer names to their corresponding
+                sparsity_config
+            2. A list of layer names to ignore for sparsity
         """
         if not (sparsity_config := config.get(SPARSITY_CONFIG_NAME)):
-            return dict()
+            return dict(), []
 
         sparsity_config = SparsityCompressionConfig.model_validate(
             sparsity_config)
@@ -127,7 +133,8 @@ class CompressedTensorsConfig(QuantizationConfig):
             target: sparsity_config
             for target in sparsity_config.targets or list()
         }
-        return sparse_scheme_map
+        sparsity_ignore_list = sparsity_config.ignore or list()
+        return sparse_scheme_map, sparsity_ignore_list
 
     @classmethod
     def _quantization_scheme_map_from_config(
```
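For intuition, here is a minimal standalone sketch of the new parsing contract. The `targets` and `ignore` keys come from the diff above; the `format` value, the regex-style target strings, and the `parse_sparsity_config` helper name are illustrative assumptions, not taken from this PR.

```python
from typing import Any, Dict, List, Tuple

def parse_sparsity_config(
        sparsity_config: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
    """Sketch of the target-map / ignore-list split (not the vLLM source)."""
    if not sparsity_config:
        return dict(), []
    # Every listed target shares the same sparsity config object.
    scheme_map = {
        target: sparsity_config
        for target in sparsity_config.get("targets") or []
    }
    # Layers named here must not be routed to a sparse (Cutlass 2:4) scheme.
    ignore_list = sparsity_config.get("ignore") or []
    return scheme_map, ignore_list

# Hypothetical partial-2:4 config: layers 0-5 sparse, the rest ignored.
example = {
    "format": "sparse-24-bitmask",           # assumed format name
    "targets": ["re:model.layers.[0-5].*"],  # illustrative targets
    "ignore": ["re:model.layers.[6-9].*"],   # illustrative ignore list
}
scheme_map, ignore_list = parse_sparsity_config(example)
assert "re:model.layers.[0-5].*" in scheme_map
assert ignore_list == ["re:model.layers.[6-9].*"]
```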
`vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py` (continued):

```diff
@@ -352,7 +359,6 @@ class CompressedTensorsConfig(QuantizationConfig):
         """
         compressed-tensors supports non uniform in the following way:
-
         ignore: List of layer_names or nn.Module names to be ignored.
         targets of config_groups: There can be N config_groups which each
         have a quantization scheme. Each config_group has a list of targets
         which can be a full layer_name, a regex for a layer_name, or
@@ -370,6 +376,8 @@ class CompressedTensorsConfig(QuantizationConfig):
         # need to make accelerate optional in ct to do this
 
         # Will be empty for models with only sparsity
+        weight_quant = input_quant = None
+        sparsity_scheme: Optional[SparsityCompressionConfig] = None
         if self.target_scheme_map:
             matched_target = find_matched_target(
                 layer_name=layer_name,
@@ -379,19 +387,24 @@ class CompressedTensorsConfig(QuantizationConfig):
             scheme_dict = self.target_scheme_map[matched_target]
             weight_quant = scheme_dict.get("weights")
             input_quant = scheme_dict.get("input_activations")
-        elif self.sparsity_scheme_map:
-            matched_target = find_matched_target(
-                layer_name=layer_name,
-                module=layer,
-                targets=self.sparsity_scheme_map.keys())
-            weight_quant = None
-            input_quant = None
-
-            # For models with sparsity, assumes that the sparse layers are also
-            # quantized for cutlass 2:4 support
-            sparsity_scheme: Optional[
-                SparsityCompressionConfig] = self.sparsity_scheme_map.get(
-                    matched_target)
+
+        if self.sparsity_scheme_map:
+            is_ignored = False
+            with suppress(ValueError):
+                is_ignored = find_matched_target(
+                    layer_name=layer_name,
+                    module=layer,
+                    targets=self.sparsity_ignore_list)
+
+            # if the layer is in the sparsity ignore list,
+            # we should not apply any sparsity scheme
+
+            if not is_ignored:
+                matched_target = find_matched_target(
+                    layer_name=layer_name,
+                    module=layer,
+                    targets=self.sparsity_scheme_map.keys())
+                sparsity_scheme = self.sparsity_scheme_map.get(matched_target)
 
         if self.supports_cutlass_24(weight_quant=weight_quant,
                                     input_quant=input_quant,
```
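The heart of the fix is the probe-and-suppress pattern above: `find_matched_target` raises `ValueError` when nothing matches, so membership in the ignore list is tested by attempting a match and swallowing the failure with `contextlib.suppress`. Below is a self-contained sketch of that control flow, using a simplified string matcher in place of vLLM's `find_matched_target` (which also handles regex targets and module types); the helper names are hypothetical.

```python
from contextlib import suppress
from typing import Dict, Iterable, Optional

def find_matched_target(layer_name: str, targets: Iterable[str]) -> str:
    # Simplified stand-in: exact or suffix match only.
    for target in targets:
        if layer_name == target or layer_name.endswith(f".{target}"):
            return target
    raise ValueError(f"Unable to find matching target for {layer_name}")

def pick_sparsity_scheme(layer_name: str, scheme_map: Dict[str, object],
                         ignore_list: Iterable[str]) -> Optional[object]:
    is_ignored = False
    with suppress(ValueError):
        # A successful match means this layer was explicitly excluded
        # from sparsity, so it must not get the Cutlass 2:4 path.
        is_ignored = bool(find_matched_target(layer_name, ignore_list))

    if is_ignored:
        return None
    matched = find_matched_target(layer_name, scheme_map.keys())
    return scheme_map.get(matched)

# Layer 0 sparse, layer 6 ignored (mirrors the example output in the
# description above).
schemes = {"layers.0.mlp.down_proj": "CompressedTensors24"}
assert pick_sparsity_scheme("model.layers.0.mlp.down_proj", schemes,
                            ["layers.6.mlp.down_proj"]) == "CompressedTensors24"
assert pick_sparsity_scheme("model.layers.6.mlp.down_proj",
                            {"layers.6.mlp.down_proj": "x"},
                            ["layers.6.mlp.down_proj"]) is None
```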
`vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py` (continued):

```diff
@@ -419,6 +432,8 @@ class CompressedTensorsConfig(QuantizationConfig):
         # Raise error if device does not support the scheme
         # (e.g. fp8 needs ada lovelace)
         self._check_scheme_supported(scheme.get_min_capability())
+        logger.debug("Using scheme: %s for %s", scheme.__class__.__name__,
+                     layer_name)
         return scheme
 
     def get_cache_scale(self, name: str) -> Optional[str]:
```
`vllm/model_executor/layers/quantization/compressed_tensors/utils.py`:

```diff
@@ -12,7 +12,7 @@ def is_activation_quantization_format(format: str) -> bool:
     _ACTIVATION_QUANTIZATION_FORMATS = [
         CompressionFormat.naive_quantized.value,
         CompressionFormat.int_quantized.value,
-        CompressionFormat.float_quantized.value
+        CompressionFormat.float_quantized.value,
     ]
     return format in _ACTIVATION_QUANTIZATION_FORMATS
 
@@ -77,6 +77,53 @@ def check_equal_or_regex_match(layer_name: str,
     return False
 
 
+def _handle_fused_layers(func):
+    """
+    Decorator to handle fused layers by mapping vllm fused layer names
+    to their corresponding unfused layer names for quantization/pruning schemes.
+    """
+    # fused_layer_name -> unfused_layer_name
+    fused_layer_map = {
+        "qkv_proj": "q_proj",
+        "gate_up_proj": "up_proj",
+    }
+
+    def fused_layer_handler(layer_name: Optional[str], module: Module,
+                            targets: Iterable[str]) -> Optional[str]:
+        """
+        Wrapper function specifically designed to support the
+        find_matched_target function.
+
+        It handles cases where the provided layer name corresponds to a
+        fused layer in vllm, mapping it to its equivalent unfused layer name
+        based on the predefined fused_layer_map. If the original layer name
+        raises a ValueError in the wrapped function, this handler
+        will attempt to resolve the issue by substituting with unfused
+        layer name.
+
+        :param layer_name: Name of the layer, which may be fused.
+        :param module: An instance of torch.nn.Module.
+        :param targets: A list of target names or patterns to match.
+        :return: The result of the wrapped find_matched_target function with
+            the resolved layer name.
+        :raises ValueError: If the layer name cannot be resolved to a
+            valid target.
+        """
+        try:
+            return func(layer_name, module, targets)
+        except ValueError:
+            if layer_name is None:
+                layer_name = ""
+            parent_name, fused_proj_name = layer_name.rsplit(".", 1)
+            unfused_proj_name = fused_layer_map.get(fused_proj_name,
+                                                    fused_proj_name)
+            new_layer_name = f"{parent_name}.{unfused_proj_name}"
+            return func(new_layer_name, module, targets)
+
+    return fused_layer_handler
+
+
+@_handle_fused_layers
 def find_matched_target(layer_name: Optional[str], module: Module,
                         targets: Iterable[str]) -> str:
     """
```
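This decorator exists because vLLM fuses the attention projections into `qkv_proj` and the MLP gate/up projections into `gate_up_proj`, while compressed-tensors configs typically name the unfused modules. A rough sketch of just the name-remapping step follows (`resolve_fused` is a hypothetical helper, not part of the PR); per the note in the description, all members of a fused layer are assumed to share one scheme, so resolving to any single member such as `q_proj` is sufficient.

```python
# fused vLLM projection name -> representative unfused name,
# mirroring fused_layer_map in the diff above
FUSED_LAYER_MAP = {"qkv_proj": "q_proj", "gate_up_proj": "up_proj"}

def resolve_fused(layer_name: str) -> str:
    """Map a fused projection name to an unfused member (illustrative)."""
    parent, _, proj = layer_name.rpartition(".")
    unfused = FUSED_LAYER_MAP.get(proj, proj)
    return f"{parent}.{unfused}" if parent else unfused

assert (resolve_fused("model.layers.0.self_attn.qkv_proj")
        == "model.layers.0.self_attn.q_proj")
assert (resolve_fused("model.layers.0.mlp.gate_up_proj")
        == "model.layers.0.mlp.up_proj")
# Unfused names pass through unchanged.
assert (resolve_fused("model.layers.0.mlp.down_proj")
        == "model.layers.0.mlp.down_proj")
```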
`vllm/model_executor/layers/quantization/compressed_tensors/utils.py` (continued):

```diff
@@ -107,8 +154,9 @@ def find_matched_target(layer_name: Optional[str], module: Module,
             or _match_fused_layer(layer_name, targets))
 
     if matched_target is None:
-        raise ValueError(f"Unable to find matching target for {module} in the "
-                         "compressed-tensors config.")
+        raise ValueError(
+            f"Unable to find matching target for {layer_name} in the "
+            "compressed-tensors config.")
 
     return matched_target
 
```