[Model][LoRA] LoRA support added for Qwen2VLForConditionalGeneration (#10022)

Signed-off-by: ericperfect <ericperfectttt@gmail.com>
Eric 2024-11-06 22:13:15 +08:00 committed by GitHub
parent a5bba7d234
commit 406d4cc480
2 changed files with 29 additions and 5 deletions
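
With this change, a LoRA adapter trained on Qwen2-VL's language-model projections can be loaded through vLLM's existing LoRA path. Below is a minimal offline-inference sketch, assuming the vLLM Python API of this era; the adapter name and the directory ./qwen2-vl-lora are placeholders, not part of this commit:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Enable the LoRA machinery when building the engine.
llm = LLM(
    model="Qwen/Qwen2-VL-2B-Instruct",
    enable_lora=True,
    max_lora_rank=16,  # must cover the adapter's rank
)

sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

# Name, integer id, and local path of the adapter; the path is a placeholder.
lora_request = LoRARequest("qwen2vl-adapter", 1, "./qwen2-vl-lora")

outputs = llm.generate(
    ["Summarize what a LoRA adapter changes in a transformer."],
    sampling_params,
    lora_request=lora_request,
)
print(outputs[0].outputs[0].text)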

docs/source/models/supported_models.rst

@@ -540,7 +540,7 @@ Text Generation
     - Qwen2-VL
     - T + I\ :sup:`E+` + V\ :sup:`+`
     - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
-    -
+    - ✅︎
     - ✅︎
   * - :code:`UltravoxModel`
     - Ultravox

vllm/model_executor/models/qwen2_vl.py

@@ -40,7 +40,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
 from vllm.attention import AttentionMetadata
 from vllm.attention.selector import _Backend
-from vllm.config import CacheConfig, MultiModalConfig
+from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, parallel_state
 from vllm.distributed import utils as dist_utils
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
@@ -65,7 +65,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.config import uses_mrope
 from vllm.transformers_utils.processor import cached_get_processor
-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (PPMissingLayer, get_vit_attn_backend,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory)
@@ -927,13 +927,37 @@ def input_processor_for_qwen2_vl(
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl)
 class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                      SupportsPP):
+                                      SupportsLoRA, SupportsPP):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    # TODO Support LoRA for the visual encoder in the future.
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []
 
     def __init__(self,
                  config: Qwen2VLConfig,
                  multimodal_config: MultiModalConfig,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 lora_config: Optional[LoRAConfig] = None) -> None:
         super().__init__()
         assert not cache_config.enable_prefix_caching, \
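
The packed_modules_mapping and supported_lora_modules attributes added here tell vLLM's LoRA loader which adapter tensors it may accept and how names from an HF/PEFT-style checkpoint (separate q_proj/k_proj/v_proj and gate_proj/up_proj tensors) fold onto vLLM's fused qkv_proj and gate_up_proj layers; per the TODO, vision-encoder modules remain unsupported. The following is a rough standalone sketch of that name-folding idea, using hypothetical helper functions rather than vLLM's internal loader code:

from typing import Dict, List

# Mirrors the class attributes added in this commit.
PACKED_MODULES_MAPPING: Dict[str, List[str]] = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}
SUPPORTED_LORA_MODULES = ["qkv_proj", "o_proj", "gate_up_proj", "down_proj"]

def target_module(checkpoint_module: str) -> str:
    """Map a module name found in a LoRA checkpoint (e.g. 'q_proj') to the
    fused vLLM module it is applied to (e.g. 'qkv_proj')."""
    for packed, parts in PACKED_MODULES_MAPPING.items():
        if checkpoint_module in parts:
            return packed
    # Unpacked modules (o_proj, down_proj) keep their own name.
    return checkpoint_module

def is_supported(checkpoint_module: str) -> bool:
    """An adapter tensor is usable only if its target module is listed."""
    return target_module(checkpoint_module) in SUPPORTED_LORA_MODULES

assert target_module("q_proj") == "qkv_proj"
assert target_module("up_proj") == "gate_up_proj"
assert is_supported("o_proj") and not is_supported("patch_embed")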