[Model][LoRA]LoRA support added for MolmoForCausalLM (#11439)

Signed-off-by: Matthias Vogler <matthias.vogler@joesecurity.org>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Matthias Vogler <matthias.vogler@joesecurity.org>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Matthias Vogler 2024-12-31 02:33:06 +01:00 committed by GitHub
parent ccb1aabcca
commit a2a40bcd0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 43 additions and 4 deletions

View File

@ -666,7 +666,7 @@ See [this page](#generative-models) for more information on how to use generativ
- Molmo - Molmo
- T + I - T + I
- `allenai/Molmo-7B-D-0924`, `allenai/Molmo-72B-0924`, etc. - `allenai/Molmo-7B-D-0924`, `allenai/Molmo-72B-0924`, etc.
- - ✅︎
- ✅︎ - ✅︎
- ✅︎ - ✅︎
* - `NVLM_D_Model` * - `NVLM_D_Model`

View File

@ -36,6 +36,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
@ -43,7 +44,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData) SequenceData)
from vllm.transformers_utils.processor import get_processor from vllm.transformers_utils.processor import get_processor
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
@ -1161,8 +1162,8 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo)
@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo) @INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA):
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={ orig_to_new_substr={
# vision backbone mapping # vision backbone mapping
@ -1191,6 +1192,32 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
}, },
) )
packed_modules_mapping = {
"qkv_proj": ["qkv_proj"],
"gate_up_proj": ["gate_up_proj"], # language model
"merged_linear": ["gate_proj", "up_proj"] # image_projector
}
# LoRA specific attributes
supported_lora_modules = [
# language model
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj", # same name with image_projector
# vision tower
"wq",
"wk",
"wv",
"wo",
"w1",
"w2",
# image_projector
"merged_linear",
]
embedding_modules = {}
embedding_padding_modules = []
# BitandBytes specific attributes # BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = { bitsandbytes_stacked_params_mapping = {
"gate_proj": ("merged_linear", 0), "gate_proj": ("merged_linear", 0),
@ -1202,8 +1229,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.lora_config = lora_config
vision_config = VisionBackboneConfig() vision_config = VisionBackboneConfig()
self.vision_backbone = MolmoVisionBackbone(config, vision_config, self.vision_backbone = MolmoVisionBackbone(config, vision_config,
@ -1377,6 +1406,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
weights = _get_weights_with_merged_embedding(weights) weights = _get_weights_with_merged_embedding(weights)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="model",
connector="vision_backbone.image_projector",
tower_model="vision_backbone",
)
def _get_weights_with_merged_embedding( def _get_weights_with_merged_embedding(
weights: Iterable[Tuple[str, torch.Tensor]] weights: Iterable[Tuple[str, torch.Tensor]]