[Bugfix] Fix MiniCPM's LoRA bug (#9286)
Parent: 2b184ddd4f
Commit: 250e26a63e
@@ -337,7 +337,11 @@ class LoRAModelManager(AdapterModelManager):
         self.packed_modules_mapping = copy.deepcopy(
             self.model.packed_modules_mapping)
         # Used to indicate whether the model is a multimodal model
-        self.supports_mm: bool = supports_multimodal(self.model)
+        self.supports_mm: bool = (
+            supports_multimodal(self.model)
+            # In case the model only supports LoRA for
+            # text modules (e.g. ChatGLM)
+            and hasattr(self.model, "get_mm_mapping"))
         self.packed_modules: Dict[str, List[str]] = {}
         self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
         # Dict instead of a Set for compatibility with LRUCache.
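The new guard is stricter than a bare supports_multimodal() check: a model can pass the multimodal interface test while only exposing LoRA on its text modules, in which case it has no get_mm_mapping(). A minimal sketch of the resulting behaviour, using hypothetical stand-in classes rather than vLLM's actual interfaces:

class TextOnlyLoRAModel:
    """Hypothetical multimodal model with LoRA only on text modules."""
    supports_multimodal = True  # stand-in for vLLM's interface check

class FullMMLoRAModel(TextOnlyLoRAModel):
    """Hypothetical model that also maps LoRA onto multimodal parts."""

    def get_mm_mapping(self):
        return {"language_model": "model.layers"}

def supports_mm_lora(model) -> bool:
    # Mirrors the patched expression: both conditions must hold.
    return getattr(model, "supports_multimodal", False) and hasattr(
        model, "get_mm_mapping")

assert not supports_mm_lora(TextOnlyLoRAModel())  # treated as text-only
assert supports_mm_lora(FullMMLoRAModel())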
@@ -474,17 +474,18 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         unpadded_vocab_size = config.vocab_size
         if lora_config:
             unpadded_vocab_size += lora_config.lora_extra_vocab_size
-        if not self.config.tie_word_embeddings:
-            self.lm_head = ParallelLMHead(
-                unpadded_vocab_size,
-                config.hidden_size,
-                org_num_embeddings=config.vocab_size,
-                padding_size=DEFAULT_VOCAB_PADDING_SIZE
-                # We need bigger padding if using lora for kernel
-                # compatibility
-                if not lora_config else lora_config.lora_vocab_padding_size,
-                quant_config=quant_config,
-            )
+        self.lm_head = ParallelLMHead(
+            unpadded_vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE
+            # We need bigger padding if using lora for kernel
+            # compatibility
+            if not lora_config else lora_config.lora_vocab_padding_size,
+            quant_config=quant_config,
+        )
+        if config.tie_word_embeddings:
+            self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
         self.scale_width = self.config.hidden_size / self.config.dim_model_base

         self.logits_processor = LogitsProcessor(unpadded_vocab_size,
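After this change self.lm_head always exists: either as a standalone ParallelLMHead or as one tied to the input embeddings via tie_weights(). The sketch below illustrates the tying semantics in plain PyTorch (ParallelLMHead.tie_weights is vLLM-specific; the module sizes here are arbitrary):

import torch
import torch.nn as nn

embed_tokens = nn.Embedding(1000, 64)      # (vocab, hidden)
lm_head = nn.Linear(64, 1000, bias=False)  # weight shape (1000, 64)

# Tie: the head reuses the embedding matrix instead of keeping its own
# copy, so a distinct lm_head *module* still exists for LoRA to wrap.
lm_head.weight = embed_tokens.weight

hidden = torch.randn(2, 64)
logits = lm_head(hidden)  # shape (2, 1000)
assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()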
@@ -517,11 +518,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
         hidden_states = hidden_states / self.scale_width
-        if self.config.tie_word_embeddings:
-            lm_head = self.model.embed_tokens
-        else:
-            lm_head = self.lm_head
-        logits = self.logits_processor(lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits

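Because __init__ now always provides a (possibly tied) self.lm_head, compute_logits no longer branches on tie_word_embeddings. The old tied branch handed the raw embed_tokens module to the logits processor, reaching past the LoRA-wrapped head. A toy illustration of why bypassing a wrapper drops the adapter contribution (LoRAWrapped is a hypothetical stand-in, not vLLM's class):

import torch
import torch.nn as nn

class LoRAWrapped(nn.Module):
    """Stand-in for a LoRA-wrapped head: base weight plus low-rank delta."""

    def __init__(self, base: nn.Linear, rank: int = 4):
        super().__init__()
        self.base = base
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.lora_b(self.lora_a(x))

base = nn.Linear(64, 1000, bias=False)
head = LoRAWrapped(base)
x = torch.randn(2, 64)

# Calling `base` directly (the old tied-embeddings branch) silently
# drops the adapter contribution:
assert not torch.allclose(head(x), base(x))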
@@ -216,6 +216,28 @@ class MiniCPM3Model(MiniCPMModel):


 class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
+    packed_modules_mapping = {
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "kv_a_proj_with_mqa",
+        "q_a_proj",
+        "q_b_proj",
+        "kv_b_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+
+    # `embedding_modules` and `embedding_padding_modules`
+    # are inherited from MiniCPMForCausalLM

     def _init_model(self):
         self.model = MiniCPM3Model(config=self.config,
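MiniCPM3 overrides packed_modules_mapping because its MLA-style attention keeps q_a_proj, q_b_proj, kv_a_proj_with_mqa, and kv_b_proj as separate projections, so only the MLP's gate_proj/up_proj pair is fused. A rough sketch of how such a mapping is typically consumed when resolving LoRA target names (simplified, not vLLM's actual loader):

from typing import Dict, List

packed_modules_mapping: Dict[str, List[str]] = {
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def expand_targets(module: str) -> List[str]:
    """Map one fused model module to the checkpoint names it packs."""
    return packed_modules_mapping.get(module, [module])

assert expand_targets("gate_up_proj") == ["gate_proj", "up_proj"]
assert expand_targets("o_proj") == ["o_proj"]  # unfused names pass through

Without the override, the class would inherit the parent's mapping, whose fused attention entries do not correspond to any module in MiniCPM3's attention.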