diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index af0f7304..bf7e2b5b 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -218,6 +218,11 @@ See [this page](#generative-models) for more information on how to use generativ
   * `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.
   * ✅︎
   * ✅︎
+- * `BambaForCausalLM`
+  * Bamba
+  * `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B`
+  *
+  *
 - * `BloomForCausalLM`
   * BLOOM, BLOOMZ, BLOOMChat
   * `bigscience/bloom`, `bigscience/bloomz`, etc.
diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py
index de0209d0..e5896f5f 100644
--- a/vllm/model_executor/models/bamba.py
+++ b/vllm/model_executor/models/bamba.py
@@ -34,7 +34,7 @@ from vllm.utils import LayerBlockType
 
 from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
                          SupportsQuant, SupportsV0Only)
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -363,6 +363,58 @@ class BambaModel(nn.Module):
         hidden_states, _ = self.final_layernorm(hidden_states, residual)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            if "A_log" in name:
+                name = name.replace("A_log", "A")
+
+            if ".self_attn." in name:
+                name = name.replace(".self_attn", "")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                        IsHybrid, SupportsV0Only, SupportsQuant):
@@ -502,52 +554,5 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
-            if "A_log" in name:
-                name = name.replace("A_log", "A")
-
-            if ".self_attn." in name:
-                name = name.replace(".self_attn", "")
-
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip layers on other devices.
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
\ No newline at end of file
diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py
index 7d01dd37..553c524e 100644
--- a/vllm/model_executor/models/exaone.py
+++ b/vllm/model_executor/models/exaone.py
@@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.exaone import ExaoneConfig
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -313,6 +313,7 @@ class ExaoneModel(nn.Module):
         lora_config = vllm_config.lora_config
 
         self.config = config
+        self.quant_config = quant_config
         lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
         self.vocab_size = config.vocab_size + lora_vocab
@@ -384,6 +385,72 @@ class ExaoneModel(nn.Module):
         hidden_states, _ = self.ln_f(hidden_states, residual)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".c_fc_0", 0),
+            (".gate_up_proj", ".c_fc_1", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if ("rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { @@ -481,71 +548,12 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".c_fc_0", 0), - (".gate_up_proj", ".c_fc_1", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue + loader = AutoWeightsLoader( + self, # With tie_word_embeddings, we can skip lm_head.weight # The weight might appear unnecessarily in the files if the model is # processed with quantization, LoRA, fine-tuning, etc. - if self.config.tie_word_embeddings and "lm_head.weight" in name: - continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. 
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index 7154ac2e..0e67b1ec 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import RWConfig
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -395,6 +395,54 @@ class FalconModel(nn.Module):
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        total_num_heads = self.config.num_attention_heads
+        if self.config.new_decoder_architecture:
+            total_num_kv_heads = self.config.num_kv_heads
+        elif self.config.multi_query:
+            total_num_kv_heads = 1
+        else:
+            total_num_kv_heads = total_num_heads
+        num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            if is_pp_missing_parameter(name, self):
+                continue
+            param = params_dict[name]
+            if "query_key_value" in name:
+                output_dim = getattr(param, "output_dim", None)
+                loaded_weight_shape = loaded_weight.shape
+                if output_dim is not None:
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] +
+                        (total_num_kv_heads, num_query_heads_per_kv_head + 2,
+                         -1) + loaded_weight_shape[output_dim + 1:])
+                    wq = loaded_weight.narrow(
+                        output_dim + 1, 0,
+                        num_query_heads_per_kv_head).reshape(
+                            *loaded_weight_shape[:output_dim], -1,
+                            *loaded_weight_shape[output_dim + 1:])
+                    wk = loaded_weight.narrow(
+                        output_dim + 1, num_query_heads_per_kv_head,
+                        1).reshape(*loaded_weight_shape[:output_dim], -1,
+                                   *loaded_weight_shape[output_dim + 1:])
+                    wv = loaded_weight.narrow(
+                        output_dim + 1, num_query_heads_per_kv_head + 1,
+                        1).reshape(*loaded_weight_shape[:output_dim], -1,
+                                   *loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class FalconForCausalLM(nn.Module, SupportsPP):
     packed_modules_mapping = {
@@ -462,51 +510,9 @@ class FalconForCausalLM(nn.Module, SupportsPP):
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        total_num_heads = self.config.num_attention_heads
-        if self.config.new_decoder_architecture:
-            total_num_kv_heads = self.config.num_kv_heads
-        elif self.config.multi_query:
-            total_num_kv_heads = 1
-        else:
-            total_num_kv_heads = total_num_heads
-        num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if name == "lm_head.weight" and self.tie_word_embeddings:
-                # Falcon uses tied embeddings except Falcon-11b.
-                continue
-            # Skip loading extra bias for GPTQ models.
-            if name.endswith(".bias") and name not in params_dict:
-                continue
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-            if "query_key_value" in name:
-                output_dim = getattr(param, "output_dim", None)
-                loaded_weight_shape = loaded_weight.shape
-                if output_dim is not None:
-                    loaded_weight = loaded_weight.view(
-                        loaded_weight_shape[:output_dim] +
-                        (total_num_kv_heads, num_query_heads_per_kv_head + 2,
-                         -1) + loaded_weight_shape[output_dim + 1:])
-                    wq = loaded_weight.narrow(
-                        output_dim + 1, 0,
-                        num_query_heads_per_kv_head).reshape(
-                            *loaded_weight_shape[:output_dim], -1,
-                            *loaded_weight_shape[output_dim + 1:])
-                    wk = loaded_weight.narrow(
-                        output_dim + 1, num_query_heads_per_kv_head,
-                        1).reshape(*loaded_weight_shape[:output_dim], -1,
-                                   *loaded_weight_shape[output_dim + 1:])
-                    wv = loaded_weight.narrow(
-                        output_dim + 1, num_query_heads_per_kv_head + 1,
-                        1).reshape(*loaded_weight_shape[:output_dim], -1,
-                                   *loaded_weight_shape[output_dim + 1:])
-                    loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
-
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)