Fix target matching for fused layers with compressed-tensors (#12617)
Without this PR
---------------
Quantizing models with llm-compressor using a recipe that explicitly lists layer names produces a model that is not loadable by vLLM (i.e. `vllm serve <model>` fails with `raise ValueError(f"Unable to find matching target for {module} in the ...`).

Example recipe:

```
recipe = """
quantization_stage:
  run_type: oneshot
  quantization_modifiers:
    GPTQModifier:
      ignore: ["lm_head"]
      config_groups:
        group_0:
          weights:
            num_bits: 4
            type: "int"
            symmetric: true
            strategy: "group"
            group_size: 128
          targets: [
            "model.layers.0.mlp.down_proj",
            "model.layers.2.mlp.down_proj",
            "model.layers.3.mlp.down_proj",
            "model.layers.4.mlp.down_proj",
            "model.layers.5.mlp.down_proj",
            "model.layers.6.mlp.down_proj",
            "model.layers.7.mlp.down_proj",
            "model.layers.8.mlp.down_proj",
            "model.layers.9.mlp.down_proj",
            "model.layers.10.mlp.down_proj",
            "model.layers.11.mlp.down_proj",
            "model.layers.12.mlp.down_proj",
            "model.layers.13.mlp.down_proj",
            "model.layers.14.mlp.down_proj",
            "model.layers.15.mlp.down_proj",
            "model.layers.16.mlp.down_proj",
            "model.layers.17.mlp.down_proj",
            "model.layers.19.mlp.down_proj",
            "model.layers.21.mlp.down_proj",
            "model.layers.22.mlp.down_proj",
            ...
          ]
"""
```

To reproduce the vLLM error:

```bash
vllm serve nm-testing/eldar-test
```

With this PR
------------
Models are loaded correctly without any errors.
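Background for the fix (a minimal sketch, not vLLM code): vLLM fuses a checkpoint's q/k/v and gate/up projections into single `qkv_proj` / `gate_up_proj` modules at load time, so a targets list that spells out the original unfused layer names never contains the fused module's own name, and an exact-name lookup fails. The `FUSED_LAYER_NAME_MAPPING` contents below are an assumption about vLLM's mapping and may differ across versions.

```python
# Minimal sketch (not vLLM code): why exact-name target matching can miss
# fused modules. The mapping is assumed to mirror vLLM's
# FUSED_LAYER_NAME_MAPPING for attention and MLP projections.
FUSED_LAYER_NAME_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

# Targets as they might be written in a recipe (unfused names):
targets = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]

# Module name as it exists inside vLLM after fusion:
layer_name = "model.layers.0.self_attn.qkv_proj"

# A plain lookup fails, which is what triggered the ValueError before this PR:
print(layer_name in targets)  # False
```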
parent cb3e73e4c8
commit 1867c258bd
```diff
@@ -103,7 +103,8 @@ def find_matched_target(layer_name: Optional[str], module: Module,
 
     matched_target = (_find_first_match(layer_name, targets)
                       or _find_first_match(module.__class__.__name__, targets,
-                                           True))
+                                           True)
+                      or _match_fused_layer(layer_name, targets))
 
     if matched_target is None:
         raise ValueError(f"Unable to find matching target for {module} in the "
@@ -152,3 +153,41 @@ def _is_equal_or_regex_match(value: str,
     elif target == value:
         return True
     return False
+
+
+def _match_fused_layer(layer_name: str,
+                       target_layers: Iterable[str]) -> Optional[str]:
+    """
+    Match a fused layer name to its corresponding individual layer in
+    target_layers.
+
+    Examples:
+        layer_name = "model.layers.0.self_attn.qkv_proj"
+        target_layers = ["model.layers.0.self_attn.q_proj",
+                         "model.layers.0.self_attn.k_proj",
+                         "model.layers.0.self_attn.v_proj"]
+    """
+    # Split into parent path and layer type
+    # e.g., "model.layers.0.self_attn" and "qkv_proj"
+    parent_path = ".".join(layer_name.split(".")[:-1])
+    layer_type = layer_name.split(".")[-1]
+
+    if layer_type not in FUSED_LAYER_NAME_MAPPING:
+        return None
+
+    possible_layer_types = FUSED_LAYER_NAME_MAPPING[layer_type]
+
+    # Look for a target layer that:
+    # 1. Has the same parent path
+    # 2. Ends with one of the possible individual layer types
+    for target in target_layers:
+        is_same_parent = parent_path in target
+        is_matching_type = any(type_suffix in target
+                               for type_suffix in possible_layer_types)
+
+        if is_same_parent and is_matching_type and all(
+                '.'.join([parent_path, type_suffix]) in target_layers
+                for type_suffix in possible_layer_types):
+            return target
+
+    return None
```
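For illustration, a self-contained sketch of how the new `_match_fused_layer` fallback resolves a fused module name against unfused targets. The helper body is copied from the diff above so it runs outside vLLM, and `FUSED_LAYER_NAME_MAPPING` is again an assumed stand-in for vLLM's own mapping.

```python
from typing import Iterable, Optional

# Assumed to mirror vLLM's FUSED_LAYER_NAME_MAPPING; contents may differ by version.
FUSED_LAYER_NAME_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}


def _match_fused_layer(layer_name: str,
                       target_layers: Iterable[str]) -> Optional[str]:
    """Copy of the new helper, kept runnable here for illustration."""
    parent_path = ".".join(layer_name.split(".")[:-1])
    layer_type = layer_name.split(".")[-1]

    if layer_type not in FUSED_LAYER_NAME_MAPPING:
        return None

    possible_layer_types = FUSED_LAYER_NAME_MAPPING[layer_type]

    # Accept a target only if it shares the parent path, names one of the
    # unfused shard types, and all sibling shards are present in the targets.
    for target in target_layers:
        is_same_parent = parent_path in target
        is_matching_type = any(t in target for t in possible_layer_types)
        if is_same_parent and is_matching_type and all(
                ".".join([parent_path, t]) in target_layers
                for t in possible_layer_types):
            return target
    return None


targets = ["model.layers.0.self_attn.q_proj",
           "model.layers.0.self_attn.k_proj",
           "model.layers.0.self_attn.v_proj"]

# The fused module now resolves to one of its unfused targets instead of raising.
print(_match_fused_layer("model.layers.0.self_attn.qkv_proj", targets))
# -> "model.layers.0.self_attn.q_proj"
print(_match_fused_layer("model.layers.0.mlp.down_proj", targets))
# -> None (down_proj is not a fused layer type)
```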