[Model] use AutoWeightsLoader in model load_weights (#15770)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
parent 550b2801ad
commit e86c414d6a
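The change applies one pattern across Bamba, Exaone and Falcon: the hand-written weight-name handling (stacked qkv/gate_up shard mapping, GPTQ bias skips, pipeline-parallel filtering) moves down into the inner *Model class, and the outer *ForCausalLM wrapper delegates to AutoWeightsLoader, optionally skipping the lm_head prefix when word embeddings are tied. The sketch below is a toy illustration of what such a delegating loader might do; it is not vLLM's actual AutoWeightsLoader, and every name in it is made up for the example.

# Toy sketch only (not vLLM's AutoWeightsLoader). It illustrates the split
# introduced by this commit: the wrapper hands the checkpoint stream to a
# loader, the loader routes each tensor to the child module that owns it, and
# any child that defines its own load_weights() keeps doing the detailed
# mapping (stacked shards, renames, PP filtering) itself.
from typing import Iterable, List, Optional, Set, Tuple

import torch
from torch import nn


class ToyAutoWeightsLoader:

    def __init__(self, module: nn.Module,
                 skip_prefixes: Optional[List[str]] = None):
        self.module = module
        self.skip_prefixes = skip_prefixes or []

    def load_weights(
            self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        loaded: Set[str] = set()
        children = dict(self.module.named_children())
        all_params = dict(self.module.named_parameters())
        for name, tensor in weights:
            # e.g. ignore "lm_head." entirely when embeddings are tied.
            if any(name.startswith(p) for p in self.skip_prefixes):
                continue
            child_name, _, rest = name.partition(".")
            child = children.get(child_name)
            if rest and child is not None and hasattr(child, "load_weights"):
                # Delegate to the child's own load_weights (e.g. BambaModel).
                for sub in child.load_weights([(rest, tensor)]):
                    loaded.add(f"{child_name}.{sub}")
            elif name in all_params:
                # Plain copy for parameters owned directly by the wrapper.
                all_params[name].data.copy_(tensor)
                loaded.add(name)
        return loaded

With a loader of this shape in place, each wrapper's load_weights collapses to the two lines seen in the hunks below: construct the loader over self, then return loader.load_weights(weights).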
@@ -218,6 +218,11 @@ See [this page](#generative-models) for more information on how to use generative models.
   * `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.
   * ✅︎
   * ✅︎
+- * `BambaForCausalLM`
+  * Bamba
+  * `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B`
+  *
+  *
 - * `BloomForCausalLM`
   * BLOOM, BLOOMZ, BLOOMChat
   * `bigscience/bloom`, `bigscience/bloomz`, etc.
@@ -34,7 +34,7 @@ from vllm.utils import LayerBlockType

 from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
                          SupportsQuant, SupportsV0Only)
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -363,6 +363,58 @@ class BambaModel(nn.Module):
         hidden_states, _ = self.final_layernorm(hidden_states, residual)
         return hidden_states

+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            if "A_log" in name:
+                name = name.replace("A_log", "A")
+
+            if ".self_attn." in name:
+                name = name.replace(".self_attn", "")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+

 class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                        IsHybrid, SupportsV0Only, SupportsQuant):
@@ -502,52 +554,5 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,

     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
-            if "A_log" in name:
-                name = name.replace("A_log", "A")
-
-            if ".self_attn." in name:
-                name = name.replace(".self_attn", "")
-
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip layers on other devices.
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
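A detail worth calling out in the load_weights bodies kept by the inner models (here BambaModel, and ExaoneModel and FalconModel below) is the for/else construct: the else branch runs only when the loop over stacked_params_mapping completes without hitting break, i.e. when no stacked-shard rule matched the checkpoint name. A minimal, self-contained illustration, using the qkv mapping tuples from the Bamba hunk above (the resolve helper and its print output are illustrative only):

# Demonstrates the for/else flow used by the manual load_weights bodies:
# break means the stacked-shard branch handled the tensor; no break means the
# else branch falls back to a regular, whole-parameter load.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]


def resolve(name: str) -> str:
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        name = name.replace(weight_name, param_name)
        print(f"stacked load into {name}, shard {shard_id}")
        break
    else:
        print(f"regular load into {name}")
    return name


resolve("layers.0.self_attn.k_proj.weight")  # matches -> qkv_proj, shard "k"
resolve("layers.0.mlp.down_proj.weight")     # no match -> else branch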
@@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.exaone import ExaoneConfig

 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -313,6 +313,7 @@ class ExaoneModel(nn.Module):
         lora_config = vllm_config.lora_config

         self.config = config
+        self.quant_config = quant_config
         lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
         self.vocab_size = config.vocab_size + lora_vocab
@@ -384,6 +385,72 @@ class ExaoneModel(nn.Module):
         hidden_states, _ = self.ln_f(hidden_states, residual)
         return hidden_states

+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".c_fc_0", 0),
+            (".gate_up_proj", ".c_fc_1", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if ("rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if (self.quant_config is not None and
+                    (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+

 class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
@@ -481,71 +548,12 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):

     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".c_fc_0", 0),
-            (".gate_up_proj", ".c_fc_1", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                continue
+        loader = AutoWeightsLoader(
+            self,
             # With tie_word_embeddings, we can skip lm_head.weight
             # The weight might appear unnecessarily in the files if the model is
             # processed with quantization, LoRA, fine-tuning, etc.
-            if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
-            if (self.quant_config is not None and
-                    (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
-                                 loaded_weight[0])
-                weight_loader(param, loaded_weight)
-                loaded_params.add(scale_name)
-                continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
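The only wrinkle in the Exaone (and Falcon) wrappers is the skip_prefixes argument: with tie_word_embeddings the lm_head weight is the input embedding, so any lm_head.* entry that still shows up in a checkpoint (after quantization, LoRA merging, fine-tuning, and so on) should be ignored rather than loaded. A small sketch of that decision follows; SimpleNamespace stands in for the model config and is an assumption made only for this example.

# Sketch of the skip_prefixes logic used above, with an illustrative config.
from types import SimpleNamespace
from typing import List, Optional


def lm_head_skip_prefixes(config) -> Optional[List[str]]:
    # Tied embeddings: lm_head.weight duplicates the embedding table, so tell
    # the loader to drop any checkpoint entry under the "lm_head." prefix.
    return ["lm_head."] if config.tie_word_embeddings else None


print(lm_head_skip_prefixes(SimpleNamespace(tie_word_embeddings=True)))   # ['lm_head.']
print(lm_head_skip_prefixes(SimpleNamespace(tie_word_embeddings=False)))  # None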
@@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import RWConfig

 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -395,6 +395,54 @@ class FalconModel(nn.Module):
         hidden_states = self.ln_f(hidden_states)
         return hidden_states

+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        total_num_heads = self.config.num_attention_heads
+        if self.config.new_decoder_architecture:
+            total_num_kv_heads = self.config.num_kv_heads
+        elif self.config.multi_query:
+            total_num_kv_heads = 1
+        else:
+            total_num_kv_heads = total_num_heads
+        num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            if is_pp_missing_parameter(name, self):
+                continue
+            param = params_dict[name]
+            if "query_key_value" in name:
+                output_dim = getattr(param, "output_dim", None)
+                loaded_weight_shape = loaded_weight.shape
+                if output_dim is not None:
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] +
+                        (total_num_kv_heads, num_query_heads_per_kv_head + 2,
+                         -1) + loaded_weight_shape[output_dim + 1:])
+                    wq = loaded_weight.narrow(
+                        output_dim + 1, 0,
+                        num_query_heads_per_kv_head).reshape(
+                            *loaded_weight_shape[:output_dim], -1,
+                            *loaded_weight_shape[output_dim + 1:])
+                    wk = loaded_weight.narrow(
+                        output_dim + 1, num_query_heads_per_kv_head,
+                        1).reshape(*loaded_weight_shape[:output_dim], -1,
+                                   *loaded_weight_shape[output_dim + 1:])
+                    wv = loaded_weight.narrow(
+                        output_dim + 1, num_query_heads_per_kv_head + 1,
+                        1).reshape(*loaded_weight_shape[:output_dim], -1,
+                                   *loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+

 class FalconForCausalLM(nn.Module, SupportsPP):
     packed_modules_mapping = {
@@ -462,51 +510,9 @@ class FalconForCausalLM(nn.Module, SupportsPP):

     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        total_num_heads = self.config.num_attention_heads
-        if self.config.new_decoder_architecture:
-            total_num_kv_heads = self.config.num_kv_heads
-        elif self.config.multi_query:
-            total_num_kv_heads = 1
-        else:
-            total_num_kv_heads = total_num_heads
-        num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if name == "lm_head.weight" and self.tie_word_embeddings:
-                # Falcon uses tied embeddings except Falcon-11b.
-                continue
-            # Skip loading extra bias for GPTQ models.
-            if name.endswith(".bias") and name not in params_dict:
-                continue
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-            if "query_key_value" in name:
-                output_dim = getattr(param, "output_dim", None)
-                loaded_weight_shape = loaded_weight.shape
-                if output_dim is not None:
-                    loaded_weight = loaded_weight.view(
-                        loaded_weight_shape[:output_dim] +
-                        (total_num_kv_heads, num_query_heads_per_kv_head + 2,
-                         -1) + loaded_weight_shape[output_dim + 1:])
-                    wq = loaded_weight.narrow(
-                        output_dim + 1, 0,
-                        num_query_heads_per_kv_head).reshape(
-                            *loaded_weight_shape[:output_dim], -1,
-                            *loaded_weight_shape[output_dim + 1:])
-                    wk = loaded_weight.narrow(
-                        output_dim + 1, num_query_heads_per_kv_head,
-                        1).reshape(*loaded_weight_shape[:output_dim], -1,
-                                   *loaded_weight_shape[output_dim + 1:])
-                    wv = loaded_weight.narrow(
-                        output_dim + 1, num_query_heads_per_kv_head + 1,
-                        1).reshape(*loaded_weight_shape[:output_dim], -1,
-                                   *loaded_weight_shape[output_dim + 1:])
-                    loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
-
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
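The Falcon-specific part of the refactor is that the query_key_value reshuffle stays in FalconModel.load_weights (the hunk above): as that reshaping implies, the checkpoint stores the fused QKV weight grouped per KV head (its query rows followed by one key row and one value row), while the fused layer expects all query rows, then all key rows, then all value rows along the output dimension. The toy example below replays the same view/narrow/cat sequence on a small tensor; the sizes (4 query heads, 2 KV heads, head_dim 3, hidden 8) are made up for illustration.

# Toy replay of the query_key_value reshuffle with hypothetical sizes.
import torch

total_num_heads = 4
total_num_kv_heads = 2          # 2 query heads share each kv head
head_dim, hidden = 3, 8
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads

# Checkpoint layout: (kv_heads, q_per_kv + 2, head_dim, hidden), flattened to 2D.
loaded_weight = torch.arange(
    total_num_kv_heads * (num_query_heads_per_kv_head + 2) * head_dim * hidden,
    dtype=torch.float32).view(-1, hidden)

output_dim = 0
shape = loaded_weight.shape
w = loaded_weight.view(shape[:output_dim] +
                       (total_num_kv_heads, num_query_heads_per_kv_head + 2, -1) +
                       shape[output_dim + 1:])
# Slice out the query, key and value blocks per kv head, then re-fuse them.
wq = w.narrow(output_dim + 1, 0, num_query_heads_per_kv_head).reshape(
    *shape[:output_dim], -1, *shape[output_dim + 1:])
wk = w.narrow(output_dim + 1, num_query_heads_per_kv_head, 1).reshape(
    *shape[:output_dim], -1, *shape[output_dim + 1:])
wv = w.narrow(output_dim + 1, num_query_heads_per_kv_head + 1, 1).reshape(
    *shape[:output_dim], -1, *shape[output_dim + 1:])
out = torch.cat([wq, wk, wv], dim=output_dim)
print(out.shape)  # torch.Size([24, 8]): 12 query rows, then 6 key, then 6 value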