[BugFix] Fix Falcon tied embeddings (#3590)
Co-authored-by: 44670 <44670@users.noreply.github.com>
This commit is contained in:
parent f8a12ecc7f
commit af9e53496f
@@ -37,7 +37,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_reduce)
 from vllm.model_executor.parallel_utils.parallel_state import (
@@ -370,10 +370,7 @@ class FalconForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.transformer = FalconModel(config, linear_method)
-        self.lm_head = ParallelLMHead(
-            config.vocab_size,
-            config.hidden_size,
-        )
+        self.lm_head_weight = self.transformer.word_embeddings.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
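The heart of the fix: Falcon ties its LM head to the input embeddings, so the separate ParallelLMHead allocation is replaced by an alias of the transformer's embedding matrix. A minimal sketch of the idea, assuming plain PyTorch; the TiedLM class and its attribute names are illustrative, not vLLM's API:

import torch.nn as nn

class TiedLM(nn.Module):
    def __init__(self, vocab_size: int = 32, hidden_size: int = 8):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        # Tied embeddings: alias the embedding matrix instead of
        # allocating a second [vocab_size, hidden_size] parameter.
        self.lm_head_weight = self.word_embeddings.weight

model = TiedLM()
# Both attributes point at the very same Parameter storage.
assert model.lm_head_weight is model.word_embeddings.weight

Note that assigning an existing nn.Parameter to a second attribute registers it on the module under a second name, which is what the weight-loading change further down has to account for.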
@@ -394,7 +391,7 @@ class FalconForCausalLM(nn.Module):

     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                        sampling_metadata)
         return logits
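With the weight tied, compute_logits just passes the shared tensor; the projection itself is unchanged. Illustrative only (a toy stand-in, not vLLM's LogitsProcessor): logits are the hidden states projected back onto the vocabulary by the transposed embedding matrix.

import torch

hidden_states = torch.randn(3, 8)            # [num_tokens, hidden_size]
lm_head_weight = torch.randn(32, 8)          # the tied word_embeddings.weight
logits = hidden_states @ lm_head_weight.t()  # [num_tokens, vocab_size]
assert logits.shape == (3, 32)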
@@ -419,9 +416,12 @@ class FalconForCausalLM(nn.Module):
         else:
             total_num_kv_heads = total_num_heads
         num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
-        params_dict = dict(self.named_parameters())
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
+            if name == "lm_head.weight":
+                # Falcon uses tied embeddings.
+                continue
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue
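Two loader details follow from the tying. First, named_parameters() deduplicates parameters that share storage by default, so params_dict would contain only one of lm_head_weight and transformer.word_embeddings.weight; remove_duplicate=False keeps both names. Second, a checkpoint may still ship a standalone lm_head.weight entry, which is now skipped because the same tensor already arrives via the embeddings. A minimal sketch of the deduplication behavior, assuming plain PyTorch; the module and alias are hypothetical:

import torch.nn as nn

m = nn.Linear(8, 8, bias=False)
m.alias = m.weight  # register a second name for the same Parameter

print(sorted(dict(m.named_parameters())))                        # ['weight']
print(sorted(dict(m.named_parameters(remove_duplicate=False))))  # ['alias', 'weight']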