[Model] use AutoWeightsLoader for olmoe,opt,orion,persimmon,phi3_small (#16548)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
parent a018e555fd
commit 5125d72f02
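
Note (not part of the diff): all five files below follow the same pattern. The inner *Model keeps its checkpoint-specific load_weights (stacked-parameter mapping, fused-QKV reshaping, expert handling), while the outer *ForCausalLM now just delegates to AutoWeightsLoader, which walks the submodules and forwards the weights to whichever module defines load_weights. Below is a minimal sketch of that delegation; MyModel and MyForCausalLM are hypothetical stand-ins for the real classes, and only the AutoWeightsLoader and WeightsMapper arguments that actually appear in this diff are assumed to exist.

    # Illustrative sketch only; assumes it lives next to the other model files
    # in vllm/model_executor/models/ so the relative import resolves.
    from typing import Iterable, Set, Tuple

    import torch
    import torch.nn as nn

    from .utils import AutoWeightsLoader, WeightsMapper


    class MyModel(nn.Module):

        def load_weights(self, weights: Iterable[Tuple[str,
                                                        torch.Tensor]]) -> Set[str]:
            # The inner model keeps the checkpoint-specific logic
            # (stacked-parameter mapping, fused-QKV reshaping, ...); elided here.
            return set()


    class MyForCausalLM(nn.Module):
        # Optional: rename HF checkpoint prefixes to vLLM module paths,
        # e.g. "decoder." -> "model.decoder." as done for OPT below.
        hf_to_vllm_mapper = WeightsMapper(
            orig_to_new_prefix={"decoder.": "model.decoder."})

        def __init__(self):
            super().__init__()
            self.model = MyModel()

        def load_weights(self, weights: Iterable[Tuple[str,
                                                        torch.Tensor]]) -> Set[str]:
            # AutoWeightsLoader walks this module's children and hands the
            # weights to MyModel.load_weights; skip_prefixes drops tensors
            # that should not be loaded at all (tied lm_head, rotary caches).
            loader = AutoWeightsLoader(self,
                                       skip_prefixes=["rotary_emb.inv_freq"])
            return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
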
vllm/model_executor/models/olmoe.py

@@ -39,7 +39,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -255,7 +255,7 @@ class OlmoeModel(nn.Module):
         quant_config = vllm_config.quant_config
 
         self.vocab_size = config.vocab_size
-
+        self.config = config
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
@@ -308,56 +308,6 @@ class OlmoeModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
-
-class OlmoeForCausalLM(nn.Module, SupportsPP):
-
-    fall_back_to_pt_during_load = False
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = OlmoeModel(vllm_config=vllm_config,
-                                prefix=maybe_prefix(prefix, "model"))
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      quant_config=quant_config)
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = get_sampler()
-
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
@@ -380,8 +330,6 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
@@ -453,3 +401,59 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class OlmoeForCausalLM(nn.Module, SupportsPP):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+        self.model = OlmoeModel(vllm_config=vllm_config,
+                                prefix=maybe_prefix(prefix, "model"))
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = get_sampler()
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=["rotary_emb.inv_freq"],
+        )
+        return loader.load_weights(weights)
vllm/model_executor/models/opt.py

@@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -313,6 +313,43 @@ class OPTModel(nn.Module):
                              intermediate_tensors,
                              inputs_embeds=inputs_embeds)
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class OPTForCausalLM(nn.Module, SupportsPP):
     packed_modules_mapping = {
@@ -320,6 +357,10 @@ class OPTForCausalLM(nn.Module, SupportsPP):
         "gate_up_proj": ["gate_proj", "up_proj"]
     }
 
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
+        "decoder.": "model.decoder.",
+    })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -371,42 +412,9 @@ class OPTForCausalLM(nn.Module, SupportsPP):
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-        ]
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "lm_head.weight" in name and self.config.tie_word_embeddings:
-                continue
-            if name.startswith("decoder."):
-                name = "model." + name
-
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head.weight"]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
vllm/model_executor/models/orion.py

@@ -30,7 +30,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -260,6 +260,45 @@ class OrionModel(nn.Module):
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class OrionForCausalLM(nn.Module, SupportsPP):
 
@@ -314,46 +353,14 @@ class OrionForCausalLM(nn.Module, SupportsPP):
 
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=([
+                "rotary_emb.inv_freq",
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+                "rotary_emb.cos_cached",
+                "rotary_emb.sin_cached"
+            ]),
+        )
+        return loader.load_weights(weights)
vllm/model_executor/models/persimmon.py

@@ -46,7 +46,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -221,7 +221,7 @@ class PersimmonModel(nn.Module):
         quant_config = vllm_config.quant_config
 
         self.vocab_size = config.vocab_size
-
+        self.config = config
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                    config.hidden_size)
         self.start_layer, self.end_layer, self.layers = make_layers(
@@ -260,6 +260,38 @@ class PersimmonModel(nn.Module):
         hidden_states = self.final_layernorm(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if is_pp_missing_parameter(name, self):
+                continue
+            param = params_dict[name]
+
+            if "query_key_value" in name:
+                # copy from vllm/model_executor/models/bloom.py
+                # NOTE: Persimmon's fused QKV's output_dim has the shape of
+                # (num_heads * 3 * head_size), while the
+                # required shape is (3 * num_heads * head_size).
+                # Thus, we need weight conversion.
+                output_dim = getattr(param, "output_dim", None)
+                num_heads = self.config.num_attention_heads
+                if output_dim is not None:
+                    loaded_weight_shape = loaded_weight.shape
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
+                        loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = loaded_weight.transpose(
+                        output_dim, output_dim + 1)
+                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class PersimmonForCausalLM(nn.Module, SupportsPP):
 
@@ -315,39 +347,5 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                continue
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-
-            if "query_key_value" in name:
-                # copy from vllm/model_executor/models/bloom.py
-                # NOTE: Persimmon's fused QKV's output_dim has the shape of
-                # (num_heads * 3 * head_size), while the
-                # required shape is (3 * num_heads * head_size).
-                # Thus, we need weight conversion.
-                output_dim = getattr(param, "output_dim", None)
-                num_heads = self.config.num_attention_heads
-                if output_dim is not None:
-                    loaded_weight_shape = loaded_weight.shape
-                    loaded_weight = loaded_weight.view(
-                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
-                        loaded_weight_shape[output_dim + 1:])
-                    loaded_weight = loaded_weight.transpose(
-                        output_dim, output_dim + 1)
-                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
-
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
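
Aside (not part of the diff): the reshaping that PersimmonModel.load_weights keeps above converts the checkpoint's fused-QKV layout, whose output dim is grouped as (num_heads * 3 * head_size), into the (3 * num_heads * head_size) grouping the fused QKV parameter expects, as the NOTE comment in the hunk says. A standalone sketch of that view/transpose/reshape with made-up sizes:

    # Illustrative only; the sizes below are invented for the example.
    import torch

    num_heads, head_size, hidden = 4, 8, 32
    output_dim = 0  # the weight dim that holds the fused QKV output

    # Checkpoint layout: output rows grouped as [num_heads, 3, head_size].
    w = torch.randn(num_heads * 3 * head_size, hidden)

    shape = w.shape
    w = w.view(shape[:output_dim] + (num_heads, 3, -1) + shape[output_dim + 1:])
    w = w.transpose(output_dim, output_dim + 1)  # -> [3, num_heads, head_size, ...]
    w = w.reshape(shape)  # flat again, now grouped as [3, num_heads, head_size]
    assert w.shape == (num_heads * 3 * head_size, hidden)
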
vllm/model_executor/models/phi3_small.py

@@ -26,7 +26,7 @@ from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -353,10 +353,29 @@ class Phi3SmallModel(nn.Module):
         hidden_states = self.final_layernorm(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            if is_pp_missing_parameter(name, self):
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class Phi3SmallForCausalLM(nn.Module, SupportsPP):
     _tied_weights_keys = ["lm_head.weight"]
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_suffix={"rotary_emb.inv_freq": None})
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -448,21 +467,8 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if name.endswith(".bias") and name not in params_dict:
-                continue
-            if is_pp_missing_parameter(name, self):
-                continue
-            if "lm_head.weight" in name and self.config.tie_word_embeddings:
-                continue
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head.weight"]
+                           if self.config.tie_word_embeddings else None))
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)