[Model] use AutoWeightsLoader for granite, granitemoe, granitemoeshared, grok1, mixtral (#16325)
Signed-off-by: Aaron Ang <aaron.angyd@gmail.com>
parent 1da6a09274
commit a564797151
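The hunks below touch, in order, the granite, granitemoe, granitemoeshared, grok1 and mixtral model implementations. In each file the detailed per-parameter loading loop moves from the *ForCausalLM wrapper into the inner *Model class, and the wrapper's load_weights becomes a thin delegation through AutoWeightsLoader, optionally skipping checkpoint entries such as rotary embedding caches or lm_head.weight when the embeddings are tied. A minimal sketch of that wrapper-level pattern follows; it is not part of the commit, ExampleForCausalLM is a hypothetical class, and the absolute import path is assumed from the relative "from .utils import AutoWeightsLoader" lines in the diff.

# Sketch only: mirrors the delegation pattern the *ForCausalLM classes adopt
# in the hunks below.
from typing import Iterable, Set, Tuple

import torch
from torch import nn

from vllm.model_executor.models.utils import AutoWeightsLoader  # assumed path


class ExampleForCausalLM(nn.Module):

    def __init__(self, config, model: nn.Module):
        super().__init__()
        self.config = config
        self.model = model  # inner *Model that owns the detailed loader

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        skip_prefixes = ["rotary_emb.inv_freq"]
        if getattr(self.config, "tie_word_embeddings", False):
            # lm_head shares the embedding weights, so a checkpoint copy of
            # it can be skipped instead of loaded twice.
            skip_prefixes.append("lm_head.")
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        # The loader walks the submodules and hands each one the weights it
        # owns, e.g. calling self.model.load_weights(...) for the backbone.
        return loader.load_weights(weights)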
@@ -50,8 +50,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter, make_layers,
-                    maybe_prefix)
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
+                    make_layers, maybe_prefix)
 
 
 class GraniteMLP(nn.Module):
@@ -260,6 +260,7 @@ class GraniteModel(nn.Module):
         lora_config = vllm_config.lora_config
 
         self.config = config
+        self.quant_config = quant_config
         lora_vocab = (lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0
         self.vocab_size = config.vocab_size + lora_vocab
@@ -321,6 +322,65 @@ class GraniteModel(nn.Module):
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
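The stacked_params_mapping added above drives a simple rename-and-shard scheme: each per-projection checkpoint tensor is renamed onto the fused vLLM parameter and loaded into the matching shard. A small standalone illustration (not from the commit) with a hypothetical checkpoint name:

# Standalone illustration of the rename done inside the loader above.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]

ckpt_name = "model.layers.0.self_attn.k_proj.weight"  # hypothetical name

for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in ckpt_name:
        fused_name = ckpt_name.replace(weight_name, param_name)
        # -> "model.layers.0.self_attn.qkv_proj.weight" with shard_id "k":
        # the tensor is loaded into the "k" shard of the fused QKV parameter.
        print(fused_name, shard_id)
        break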
@@ -428,71 +488,18 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".gate_proj", 0),
-            (".gate_up_proj", ".up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                continue
-            # With tie_word_embeddings, we can skip lm_head.weight
-            # The weight might appear unnecessarily in the files if the model is
-            # processed with quantization, LoRA, fine-tuning, etc.
-            if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
-            if (self.quant_config is not None and
-                (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
-                                 loaded_weight[0])
-                weight_loader(param, loaded_weight)
-                loaded_params.add(scale_name)
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        skip_prefixes = [
+            "rotary_emb.inv_freq",
+            # Models trained using ColossalAI may include these tensors in
+            # the checkpoint. Skip them.
+            "rotary_emb.cos_cached",
+            "rotary_emb.sin_cached",
+        ]
+        # With tie_word_embeddings, we can skip lm_head.weight
+        # The weight might appear unnecessarily in the files if the model is
+        # processed with quantization, LoRA, fine-tuning, etc.
+        if self.config.tie_word_embeddings:
+            skip_prefixes.append("lm_head.weight")
+
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        return loader.load_weights(weights)
@@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors
 
 from . import mixtral
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import make_layers, maybe_prefix
+from .utils import AutoWeightsLoader, make_layers, maybe_prefix
 
 
 class GraniteMoeMoE(nn.Module):
@@ -252,6 +252,8 @@ class GraniteMoeModel(nn.Module):
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
 
+        self.config = config
+        self.quant_config = quant_config  # Required by MixtralModel
         lora_vocab = (lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0
         self.vocab_size = config.vocab_size + lora_vocab
@@ -304,6 +306,40 @@ class GraniteMoeModel(nn.Module):
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        new_weights = {}
+        for n, p in weights:
+            if n.endswith('.block_sparse_moe.input_linear.weight'):
+                for e in range(p.size(0)):
+                    w1_name = n.replace(
+                        '.block_sparse_moe.input_linear.weight',
+                        f".block_sparse_moe.experts.{e}.w1.weight")
+                    w3_name = n.replace(
+                        '.block_sparse_moe.input_linear.weight',
+                        f".block_sparse_moe.experts.{e}.w3.weight")
+                    w1_param, w3_param = p[e].chunk(2, dim=0)
+                    assert w1_name not in new_weights
+                    assert w3_name not in new_weights
+                    new_weights[w1_name] = w1_param
+                    new_weights[w3_name] = w3_param
+            elif n.endswith('.block_sparse_moe.output_linear.weight'):
+                for e in range(p.size(0)):
+                    w2_name = n.replace(
+                        '.block_sparse_moe.output_linear.weight',
+                        f".block_sparse_moe.experts.{e}.w2.weight")
+                    w2_param = p[e]
+                    assert w2_name not in new_weights
+                    new_weights[w2_name] = w2_param
+            elif n.endswith('.block_sparse_moe.router.layer.weight'):
+                gate_name = n.replace('.block_sparse_moe.router.layer.weight',
+                                      ".block_sparse_moe.gate.weight")
+                assert gate_name not in new_weights
+                new_weights[gate_name] = p
+            else:
+                new_weights[n] = p
+        return mixtral.MixtralModel.load_weights(self, new_weights.items())
+
 
 class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     fall_back_to_pt_during_load = False
@@ -331,7 +367,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
         self.config = config
         self.lora_config = lora_config
-        self.quant_config = quant_config  # Required by MixtralForCausalLM
 
         self.model = GraniteMoeModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))
@@ -403,37 +438,9 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        new_weights = {}
-        for n, p in weights:
-            if n.endswith('.block_sparse_moe.input_linear.weight'):
-                for e in range(p.size(0)):
-                    w1_name = n.replace(
-                        '.block_sparse_moe.input_linear.weight',
-                        f".block_sparse_moe.experts.{e}.w1.weight")
-                    w3_name = n.replace(
-                        '.block_sparse_moe.input_linear.weight',
-                        f".block_sparse_moe.experts.{e}.w3.weight")
-                    w1_param, w3_param = p[e].chunk(2, dim=0)
-                    assert w1_name not in new_weights
-                    assert w3_name not in new_weights
-                    new_weights[w1_name] = w1_param
-                    new_weights[w3_name] = w3_param
-            elif n.endswith('.block_sparse_moe.output_linear.weight'):
-                for e in range(p.size(0)):
-                    w2_name = n.replace(
-                        '.block_sparse_moe.output_linear.weight',
-                        f".block_sparse_moe.experts.{e}.w2.weight")
-                    w2_param = p[e]
-                    assert w2_name not in new_weights
-                    new_weights[w2_name] = w2_param
-            elif n.endswith('.block_sparse_moe.router.layer.weight'):
-                gate_name = n.replace('.block_sparse_moe.router.layer.weight',
-                                      ".block_sparse_moe.gate.weight")
-                assert gate_name not in new_weights
-                new_weights[gate_name] = p
-            elif n == 'lm_head.weight' and self.config.tie_word_embeddings:
-                pass
-            else:
-                new_weights[n] = p
-        return mixtral.MixtralForCausalLM.load_weights(self,
-                                                       new_weights.items())
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
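For the MoE variants, the renaming work now lives in GraniteMoeModel.load_weights above: the stacked GraniteMoE checkpoint tensors are split into Mixtral-style per-expert names before mixtral.MixtralModel.load_weights is reused. A standalone illustration (not from the commit) of the input_linear split, with made-up shapes and a hypothetical layer name:

# Illustration of the expert-weight split done in the loader above: the
# stacked input_linear tensor holds w1 and w3 for every expert, and each
# expert's slice is chunked in two.
import torch

num_experts, intermediate, hidden = 4, 16, 8  # hypothetical sizes
input_linear = torch.randn(num_experts, 2 * intermediate, hidden)

new_weights = {}
prefix = "model.layers.0.block_sparse_moe.experts"  # hypothetical prefix
for e in range(input_linear.size(0)):
    w1_param, w3_param = input_linear[e].chunk(2, dim=0)
    new_weights[f"{prefix}.{e}.w1.weight"] = w1_param
    new_weights[f"{prefix}.{e}.w3.weight"] = w3_param

# Each expert now has separate [intermediate, hidden] w1/w3 tensors under
# the Mixtral-style names that MixtralModel.load_weights expects.
assert new_weights[f"{prefix}.0.w1.weight"].shape == (intermediate, hidden)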
@@ -29,7 +29,7 @@ from vllm.sequence import IntermediateTensors
 from . import mixtral
 from .granitemoe import GraniteMoeAttention, GraniteMoeMoE
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import make_layers, maybe_prefix
+from .utils import AutoWeightsLoader, make_layers, maybe_prefix
 
 
 class GraniteMoeSharedMLP(nn.Module):
@@ -152,6 +152,8 @@ class GraniteMoeSharedModel(nn.Module):
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
 
+        self.config = config
+        self.quant_config = quant_config  # Required by MixtralModel
         self.padding_idx = config.pad_token_id
         lora_vocab = (lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0
@@ -207,6 +209,40 @@ class GraniteMoeSharedModel(nn.Module):
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        new_weights = {}
+        for n, p in weights:
+            if n.endswith('.block_sparse_moe.input_linear.weight'):
+                for e in range(p.size(0)):
+                    w1_name = n.replace(
+                        '.block_sparse_moe.input_linear.weight',
+                        f".block_sparse_moe.experts.{e}.w1.weight")
+                    w3_name = n.replace(
+                        '.block_sparse_moe.input_linear.weight',
+                        f".block_sparse_moe.experts.{e}.w3.weight")
+                    w1_param, w3_param = p[e].chunk(2, dim=0)
+                    assert w1_name not in new_weights
+                    assert w3_name not in new_weights
+                    new_weights[w1_name] = w1_param
+                    new_weights[w3_name] = w3_param
+            elif n.endswith('.block_sparse_moe.output_linear.weight'):
+                for e in range(p.size(0)):
+                    w2_name = n.replace(
+                        '.block_sparse_moe.output_linear.weight',
+                        f".block_sparse_moe.experts.{e}.w2.weight")
+                    w2_param = p[e]
+                    assert w2_name not in new_weights
+                    new_weights[w2_name] = w2_param
+            elif n.endswith('.block_sparse_moe.router.layer.weight'):
+                gate_name = n.replace('.block_sparse_moe.router.layer.weight',
+                                      ".block_sparse_moe.gate.weight")
+                assert gate_name not in new_weights
+                new_weights[gate_name] = p
+            else:
+                new_weights[n] = p
+        return mixtral.MixtralModel.load_weights(self, new_weights.items())
+
 
 class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     fall_back_to_pt_during_load = False
@@ -234,7 +270,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
         self.config = config
         self.lora_config = lora_config
-        self.quant_config = quant_config
 
         self.model = GraniteMoeSharedModel(vllm_config=vllm_config,
                                            prefix=maybe_prefix(
@@ -307,37 +342,9 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        new_weights = {}
-        for n, p in weights:
-            if n.endswith('.block_sparse_moe.input_linear.weight'):
-                for e in range(p.size(0)):
-                    w1_name = n.replace(
-                        '.block_sparse_moe.input_linear.weight',
-                        f".block_sparse_moe.experts.{e}.w1.weight")
-                    w3_name = n.replace(
-                        '.block_sparse_moe.input_linear.weight',
-                        f".block_sparse_moe.experts.{e}.w3.weight")
-                    w1_param, w3_param = p[e].chunk(2, dim=0)
-                    assert w1_name not in new_weights
-                    assert w3_name not in new_weights
-                    new_weights[w1_name] = w1_param
-                    new_weights[w3_name] = w3_param
-            elif n.endswith('.block_sparse_moe.output_linear.weight'):
-                for e in range(p.size(0)):
-                    w2_name = n.replace(
-                        '.block_sparse_moe.output_linear.weight',
-                        f".block_sparse_moe.experts.{e}.w2.weight")
-                    w2_param = p[e]
-                    assert w2_name not in new_weights
-                    new_weights[w2_name] = w2_param
-            elif n.endswith('.block_sparse_moe.router.layer.weight'):
-                gate_name = n.replace('.block_sparse_moe.router.layer.weight',
-                                      ".block_sparse_moe.gate.weight")
-                assert gate_name not in new_weights
-                new_weights[gate_name] = p
-            elif n == 'lm_head.weight' and self.config.tie_word_embeddings:
-                pass
-            else:
-                new_weights[n] = p
-        return mixtral.MixtralForCausalLM.load_weights(self,
-                                                       new_weights.items())
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
@@ -48,7 +48,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -302,6 +302,8 @@ class Grok1Model(nn.Module):
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
 
+        self.config = config
+        self.quant_config = quant_config
         self.padding_idx = config.pad_token_id
         lora_vocab = (lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0
@@ -370,6 +372,105 @@ class Grok1Model(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        # Map Grok1's unique expert parameter names to standard names
+        # Grok1 uses "num_experts" in its config
+        num_experts = getattr(self.config, "num_experts", 8)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="linear",  # Grok1 specific
+            ckpt_down_proj_name="linear_1",  # Grok1 specific
+            ckpt_up_proj_name="linear_v",  # Grok1 specific
+            num_experts=num_experts)
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+
+        for name, loaded_weight in weights:
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if ((name.endswith(".bias") or name.endswith("_bias"))
+                        and name not in params_dict):
+                    continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
+                if name.endswith("scale"):
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    if ((name.endswith(".bias") or name.endswith("_bias"))
+                            and name not in params_dict):
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if ((name.endswith(".bias") or name.endswith("_bias"))
+                            and name not in params_dict):
+                        continue
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+
+                    # Handle Grok1-specific norm.scale naming
+                    if "norm.scale" in name:
+                        name = name.replace("scale", "weight")
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     fall_back_to_pt_during_load = False
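The loader added above falls through three stages: the stacked QKV mapping, then the expert mapping built from Grok1's linear/linear_1/linear_v checkpoint names, then plain parameters. It leans on Python's for/else, where the else body runs only if the loop finished without hitting break. A tiny standalone illustration (not from the commit, with hypothetical names) of that control flow:

# Illustration of the for/else fallthrough used by the loaders above.
def classify(name: str) -> str:
    stacked = ("q_proj", "k_proj", "v_proj")
    experts = ("linear", "linear_1", "linear_v")
    kind = "plain parameter"
    for fragment in stacked:
        if fragment not in name:
            continue
        kind = "stacked shard"
        break
    else:
        # Runs only when no stacked fragment matched.
        for fragment in experts:
            if fragment not in name:
                continue
            kind = "expert shard"
            break
    return kind


print(classify("model.layers.0.attn.k_proj.weight"))            # stacked shard
print(classify("model.layers.0.moe.experts.3.linear_v.weight"))  # expert shard
print(classify("model.norm.scale"))                              # plain parameter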
@@ -460,106 +561,10 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-        ]
-
-        # Map Grok1's unique expert parameter names to standard names
-        # Grok1 uses "num_experts" in its config
-        num_experts = getattr(self.config, "num_experts", 8)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="linear",  # Grok1 specific
-            ckpt_down_proj_name="linear_1",  # Grok1 specific
-            ckpt_up_proj_name="linear_v",  # Grok1 specific
-            num_experts=num_experts)
-
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
-            if (self.quant_config is not None and
-                (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
-                                 loaded_weight[0])
-                weight_loader(param, loaded_weight)
-                loaded_params.add(scale_name)
-                continue
-
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if ((name.endswith(".bias") or name.endswith("_bias"))
-                        and name not in params_dict):
-                    continue
-                # Skip layers on other devices.
-                if is_pp_missing_parameter(name, self):
-                    continue
-                if name.endswith("scale"):
-                    # Remapping the name of FP8 kv-scale.
-                    name = maybe_remap_kv_scale_name(name, params_dict)
-                    if name is None:
-                        continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
-                        continue
-                    if ((name.endswith(".bias") or name.endswith("_bias"))
-                            and name not in params_dict):
-                        continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if ((name.endswith(".bias") or name.endswith("_bias"))
-                            and name not in params_dict):
-                        continue
-                    # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
-                        continue
-
-                    # Remapping the name of FP8 kv-scale.
-                    name = maybe_remap_kv_scale_name(name, params_dict)
-                    if name is None:
-                        continue
-
-                    # Handle Grok1-specific norm.scale naming
-                    if "norm.scale" in name:
-                        name = name.replace("scale", "weight")
-
-                    # Skip lm_head when tie_word_embeddings is True
-                    if "lm_head" in name and self.config.tie_word_embeddings:
-                        continue
-
-                    param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        skip_prefixes = ["rotary_emb.inv_freq"]
+        # Skip lm_head when tie_word_embeddings is True
+        if self.config.tie_word_embeddings:
+            skip_prefixes.append("lm_head")
+
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        return loader.load_weights(weights)
@@ -49,7 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -260,6 +260,8 @@ class MixtralModel(nn.Module):
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
 
+        self.config = config
+        self.quant_config = quant_config
         lora_vocab = (lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0
         self.vocab_size = config.vocab_size + lora_vocab
@@ -313,6 +315,98 @@ class MixtralModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts)
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if ((name.endswith(".bias") or name.endswith("_bias"))
+                        and name not in params_dict):
+                    continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
+                if name.endswith("scale"):
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    if ((name.endswith(".bias") or name.endswith("_bias"))
+                            and name not in params_dict):
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if ((name.endswith(".bias") or name.endswith("_bias"))
+                            and name not in params_dict):
+                        continue
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     fall_back_to_pt_during_load = False
@@ -397,95 +491,5 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-        ]
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="w1",
-            ckpt_down_proj_name="w2",
-            ckpt_up_proj_name="w3",
-            num_experts=self.config.num_local_experts)
-
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
-            if (self.quant_config is not None and
-                (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
-                                 loaded_weight[0])
-                weight_loader(param, loaded_weight)
-                loaded_params.add(scale_name)
-                continue
-
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if ((name.endswith(".bias") or name.endswith("_bias"))
-                        and name not in params_dict):
-                    continue
-                # Skip layers on other devices.
-                if is_pp_missing_parameter(name, self):
-                    continue
-                if name.endswith("scale"):
-                    # Remapping the name of FP8 kv-scale.
-                    name = maybe_remap_kv_scale_name(name, params_dict)
-                    if name is None:
-                        continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
-                        continue
-                    if ((name.endswith(".bias") or name.endswith("_bias"))
-                            and name not in params_dict):
-                        continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if ((name.endswith(".bias") or name.endswith("_bias"))
-                            and name not in params_dict):
-                        continue
-                    # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
-                        continue
-                    # Remapping the name of FP8 kv-scale.
-                    name = maybe_remap_kv_scale_name(name, params_dict)
-                    if name is None:
-                        continue
-
-                    param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
+        return loader.load_weights(weights)