[Model] use AutoWeightsLoader in model load_weights (#15770)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
This commit is contained in:
rongfu.leng 2025-04-02 22:47:31 +08:00 committed by GitHub
parent 550b2801ad
commit e86c414d6a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 189 additions and 165 deletions

View File

@ -218,6 +218,11 @@ See [this page](#generative-models) for more information on how to use generativ
* `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.
* ✅︎
* ✅︎
- * `BambaForCausalLM`
* Bamba
* `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B`
*
*
- * `BloomForCausalLM`
* BLOOM, BLOOMZ, BLOOMChat
* `bigscience/bloom`, `bigscience/bloomz`, etc.

View File

@ -34,7 +34,7 @@ from vllm.utils import LayerBlockType
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
SupportsQuant, SupportsV0Only)
from .utils import (is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@ -363,6 +363,58 @@ class BambaModel(nn.Module):
hidden_states, _ = self.final_layernorm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "A_log" in name:
name = name.replace("A_log", "A")
if ".self_attn." in name:
name = name.replace(".self_attn", "")
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
IsHybrid, SupportsV0Only, SupportsQuant):
@ -502,52 +554,5 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "A_log" in name:
name = name.replace("A_log", "A")
if ".self_attn." in name:
name = name.replace(".self_attn", "")
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)

View File

@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.exaone import ExaoneConfig
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@ -313,6 +313,7 @@ class ExaoneModel(nn.Module):
lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab
@ -384,6 +385,72 @@ class ExaoneModel(nn.Module):
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".c_fc_0", 0),
(".gate_up_proj", ".c_fc_1", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
@ -481,71 +548,12 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".c_fc_0", 0),
(".gate_up_proj", ".c_fc_1", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
loader = AutoWeightsLoader(
self,
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)

View File

@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import RWConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@ -395,6 +395,54 @@ class FalconModel(nn.Module):
hidden_states = self.ln_f(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
total_num_heads = self.config.num_attention_heads
if self.config.new_decoder_architecture:
total_num_kv_heads = self.config.num_kv_heads
elif self.config.multi_query:
total_num_kv_heads = 1
else:
total_num_kv_heads = total_num_heads
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
if "query_key_value" in name:
output_dim = getattr(param, "output_dim", None)
loaded_weight_shape = loaded_weight.shape
if output_dim is not None:
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] +
(total_num_kv_heads, num_query_heads_per_kv_head + 2,
-1) + loaded_weight_shape[output_dim + 1:])
wq = loaded_weight.narrow(
output_dim + 1, 0,
num_query_heads_per_kv_head).reshape(
*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
wk = loaded_weight.narrow(
output_dim + 1, num_query_heads_per_kv_head,
1).reshape(*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
wv = loaded_weight.narrow(
output_dim + 1, num_query_heads_per_kv_head + 1,
1).reshape(*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class FalconForCausalLM(nn.Module, SupportsPP):
packed_modules_mapping = {
@ -462,51 +510,9 @@ class FalconForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
total_num_heads = self.config.num_attention_heads
if self.config.new_decoder_architecture:
total_num_kv_heads = self.config.num_kv_heads
elif self.config.multi_query:
total_num_kv_heads = 1
else:
total_num_kv_heads = total_num_heads
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if name == "lm_head.weight" and self.tie_word_embeddings:
# Falcon uses tied embeddings except Falcon-11b.
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
if "query_key_value" in name:
output_dim = getattr(param, "output_dim", None)
loaded_weight_shape = loaded_weight.shape
if output_dim is not None:
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] +
(total_num_kv_heads, num_query_heads_per_kv_head + 2,
-1) + loaded_weight_shape[output_dim + 1:])
wq = loaded_weight.narrow(
output_dim + 1, 0,
num_query_heads_per_kv_head).reshape(
*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
wk = loaded_weight.narrow(
output_dim + 1, num_query_heads_per_kv_head,
1).reshape(*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
wv = loaded_weight.narrow(
output_dim + 1, num_query_heads_per_kv_head + 1,
1).reshape(*loaded_weight_shape[:output_dim], -1,
*loaded_weight_shape[output_dim + 1:])
loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)