[Misc] Merge bitsandbytes_stacked_params_mapping and packed_modules_mapping (#11924)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
87054a57ab
commit
a3a3ee4e6f
@ -39,7 +39,8 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
from vllm.model_executor.model_loader.tensorizer import (
|
||||
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
|
||||
serialize_vllm_model, tensorizer_weights_iterator)
|
||||
from vllm.model_executor.model_loader.utils import (get_model_architecture,
|
||||
from vllm.model_executor.model_loader.utils import (ParamMapping,
|
||||
get_model_architecture,
|
||||
set_default_torch_dtype)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
||||
@ -983,21 +984,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
|
||||
def _get_bnb_target_modules(self, model: nn.Module) -> None:
|
||||
|
||||
# TODO: Maybe we can replace bitsandbytes_stacked_params_mapping with
|
||||
# packed_modules_mapping.
|
||||
inverse_stacked_mapping: Dict[str, List[str]] = {}
|
||||
for orig, (
|
||||
packed,
|
||||
idx,
|
||||
) in model.bitsandbytes_stacked_params_mapping.items():
|
||||
if packed not in inverse_stacked_mapping:
|
||||
inverse_stacked_mapping[packed] = []
|
||||
inverse_stacked_mapping[packed].insert(idx, orig)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, (LinearBase, )):
|
||||
last_name = name.split(".")[-1]
|
||||
if sub_modules := inverse_stacked_mapping.get(last_name, []):
|
||||
if sub_modules := self.modules_mapping.packed_mapping.get(
|
||||
last_name, []):
|
||||
# Map vllm's names to transformers's names.
|
||||
for sub_name in sub_modules:
|
||||
self.target_modules.append(
|
||||
@ -1018,15 +1009,19 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
"The required method 'load_weights' is not defined in class"
|
||||
f" {type(model).__name__}.")
|
||||
|
||||
if not hasattr(model, "bitsandbytes_stacked_params_mapping"):
|
||||
if not hasattr(model, "packed_modules_mapping"):
|
||||
raise AttributeError(
|
||||
f"Model {type(model).__name__} does not support BitsAndBytes "
|
||||
"quantization yet.")
|
||||
"quantization yet. No 'packed_modules_mapping' found.")
|
||||
|
||||
self.modules_mapping = ParamMapping(
|
||||
copy.deepcopy(model.packed_modules_mapping))
|
||||
|
||||
# For some models like Molmo, we need to use hf_to_vllm_mapper
|
||||
# to ensure correct loading of weights.
|
||||
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
|
||||
self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
|
||||
|
||||
# Modules whose weights might have fused on disk
|
||||
# we need their output_sizes to make shard in flight correctly with TP
|
||||
self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
|
||||
@ -1109,7 +1104,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
for shard_name, (
|
||||
weight_name,
|
||||
index,
|
||||
) in model.bitsandbytes_stacked_params_mapping.items():
|
||||
) in self.modules_mapping.inverse_packed_mapping.items():
|
||||
shard_pos = quant_param_name.find(shard_name)
|
||||
# Some models, such as MiniCPM V2.5/2.6, contain both
|
||||
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Utilities for selecting and loading models."""
|
||||
import contextlib
|
||||
from typing import Tuple, Type
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -49,3 +50,26 @@ def get_model_architecture(
|
||||
|
||||
def get_architecture_class_name(model_config: ModelConfig) -> str:
|
||||
return get_model_architecture(model_config)[1]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParamMapping:
|
||||
"""
|
||||
A class to handle parameter mapping for model weight loading.
|
||||
It creates a bidirectional mapping between packed parameters and their
|
||||
constituent parts.
|
||||
"""
|
||||
packed_mapping: Dict[str, List[str]]
|
||||
inverse_packed_mapping: Dict[str, Tuple[str,
|
||||
int]] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
for packed_name, sub_params in self.packed_mapping.items():
|
||||
# Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
|
||||
if len(sub_params) == 1 and sub_params[0] == packed_name:
|
||||
continue
|
||||
for index, param_name in enumerate(sub_params):
|
||||
self.inverse_packed_mapping[param_name] = (
|
||||
packed_name,
|
||||
index,
|
||||
)
|
||||
|
@ -350,13 +350,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
|
@ -430,14 +430,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
"lm_head": "output_embeddings",
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"c_fc_0": ("gate_up_proj", 0),
|
||||
"c_fc_1": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
@ -409,9 +409,9 @@ class FalconModel(nn.Module):
|
||||
|
||||
|
||||
class FalconForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {}
|
||||
packed_modules_mapping = {
|
||||
"query_key_value": ["query_key_value"],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
@ -349,15 +349,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
"gate_up_proj",
|
||||
"down_proj",
|
||||
]
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
# Gemma does not apply LoRA to the embedding layer.
|
||||
embedding_modules = {}
|
||||
|
@ -399,16 +399,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
@ -362,14 +362,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
"lm_head": "output_embeddings",
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
@ -662,16 +662,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"down_proj",
|
||||
]
|
||||
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
|
@ -478,16 +478,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
# Mistral/Llama models can also be loaded with --load-format mistral
|
||||
# from consolidated.safetensors checkpoints
|
||||
mistral_mapping = {
|
||||
|
@ -463,14 +463,10 @@ def init_vision_tower_for_llava(
|
||||
info=_build_llava_or_pixtral_hf_info,
|
||||
dummy_inputs=LlavaDummyInputsBuilder)
|
||||
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
|
@ -534,16 +534,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
@ -241,11 +241,5 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
|
||||
# `embedding_modules` and `embedding_padding_modules`
|
||||
# are inherited from MiniCPMForCausalLM
|
||||
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix)
|
||||
|
@ -761,16 +761,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
|
||||
"kv_proj",
|
||||
]
|
||||
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
@ -881,16 +871,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
|
||||
"kv_proj",
|
||||
]
|
||||
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
|
@ -1107,14 +1107,9 @@ class MllamaForCausalLM(nn.Module):
|
||||
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
|
||||
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
|
@ -1193,12 +1193,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
"gate_proj": ("merged_linear", 0),
|
||||
"up_proj": ("merged_linear", 1),
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
@ -395,12 +395,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
"lm_head": "output_embeddings",
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
@ -329,13 +329,9 @@ class OPTModel(nn.Module):
|
||||
|
||||
|
||||
class OPTForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
|
@ -279,14 +279,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
"fc2",
|
||||
]
|
||||
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
}
|
||||
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
|
@ -14,7 +14,3 @@ class Phi3ForCausalLM(LlamaForCausalLM):
|
||||
"gate_up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
# BitandBytes specific attributes
|
||||
# Initialize an empty dict when there is no stacked parameter mapping.
|
||||
bitsandbytes_stacked_params_mapping = {}
|
||||
|
@ -1028,13 +1028,6 @@ class QWenLLM(QWenBaseModel):
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"w2": ("gate_up_proj", 0),
|
||||
"w1": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
|
||||
class QWenVL(QWenBaseModel, SupportsMultiModal):
|
||||
packed_modules_mapping = {
|
||||
|
@ -418,16 +418,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
@ -1038,16 +1038,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
# To ensure correct weight loading and mapping.
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
|
@ -401,14 +401,6 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
"lm_head": "output_embeddings",
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
Loading…
x
Reference in New Issue
Block a user