[Model][LoRA]LoRA support added for Qwen (#9622)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Authored by Jee Jee Li on 2024-10-29 12:14:07 +08:00; committed by GitHub
parent c5d7fb9ddc
commit 7a4df5f200
2 changed files with 101 additions and 14 deletions

diff --git a/vllm/lora/models.py b/vllm/lora/models.py

@@ -578,10 +578,10 @@ class LoRAModelManager(AdapterModelManager):
         be filtered out.
         """
         if self.supports_mm:
-            prefix = module_name.split(".")[0]
             module_mapping: MultiModelKeys = self.model.get_mm_mapping()
-            return (prefix in module_mapping.connector
-                    or prefix in module_mapping.tower_model)
+            prefix_lst = module_mapping.connector + module_mapping.tower_model
+            return any(
+                [module_name.startswith(prefix) for prefix in prefix_lst])
         return False

     def _register_packed_modules(self, module_full_name: str) -> None:
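The new check compares the full dotted module name against each multimodal prefix instead of only its first path component, which is what lets the nested Qwen-VL prefixes registered later in this commit (e.g. "transformer.visual.attn_pool") be matched. Below is a minimal standalone sketch of that rule using those prefixes; the helper name is illustrative, not part of vLLM.

```python
# Sketch of the new filtering rule (not vLLM code). LoRA target modules whose
# full dotted name starts with a connector or vision-tower prefix are skipped,
# so only language-model modules receive LoRA weights.
CONNECTOR = ["transformer.visual.attn_pool"]
TOWER_MODEL = ["transformer.visual.transformer"]


def is_unsupported_mm_module(module_name: str) -> bool:
    """Return True if the module belongs to the connector or vision tower."""
    prefix_lst = CONNECTOR + TOWER_MODEL
    return any(module_name.startswith(prefix) for prefix in prefix_lst)


# Language-model module: kept for LoRA.
assert not is_unsupported_mm_module("transformer.h.0.attn.c_attn")
# Vision-tower module: filtered out. The previous first-component check
# ("transformer" in prefix_lst) could never match these dotted prefixes.
assert is_unsupported_mm_module(
    "transformer.visual.transformer.resblocks.0.attn.in_proj")
```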

diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py

@@ -20,7 +20,7 @@ from torchvision.transforms import InterpolationMode
 from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig, MultiModalConfig
+from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                          token_inputs)
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -39,6 +40,7 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -46,7 +48,7 @@ from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.utils import is_list_of

-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (flatten_bn, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
@@ -122,8 +124,8 @@ class VisualAttention(nn.Module):
         # Strided linear layer.
         assert self._qkv_same_embed_dim, \
                 'Visual Attention implementation only supports self-attention'
-        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
-        self.out_proj = nn.Linear(embed_dim, embed_dim)
+        self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)
+        self.out_proj = ReplicatedLinear(embed_dim, embed_dim)
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

     def forward(
@@ -133,7 +135,7 @@ class VisualAttention(nn.Module):
     ) -> torch.Tensor:
         # query/key/value: [sq, b, h]
         sq, b, _ = x.size()
-        mixed_x_layer = self.in_proj(x)
+        mixed_x_layer, _ = self.in_proj(x)

         # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
         new_tensor_shape = mixed_x_layer.size()[:-1] + \
@@ -182,7 +184,7 @@ class VisualAttention(nn.Module):
             (self.hidden_size_per_partition,)
         context_layer = context_layer.view(*new_context_layer_shape)

-        output = self.out_proj(context_layer)
+        output, _ = self.out_proj(context_layer)

         return output
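The added `, _` unpacking follows from the return convention of vLLM's linear layers: unlike torch.nn.Linear, layers such as ReplicatedLinear return an (output, output_bias) tuple from forward. A toy stand-in (not the real ReplicatedLinear) showing the calling convention these call sites now assume:

```python
import torch
import torch.nn as nn


class TupleReturningLinear(nn.Module):
    """Toy stand-in (not vLLM's ReplicatedLinear) mimicking the
    (output, output_bias) return convention of vLLM's linear layers."""

    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x: torch.Tensor):
        # vLLM's linear layers return the bias separately so callers can
        # decide where (or whether) to add it; None here stands for a bias
        # that was already applied inside the layer.
        return self.linear(x), None


proj = TupleReturningLinear(16, 48)
out, _ = proj(torch.randn(4, 16))  # tuple unpacking, as in the diff above
assert out.shape == (4, 48)
```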
@@ -860,11 +862,7 @@ def dummy_data_for_qwen(
     return seq_data, mm_data


-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
-class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
+class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):

     def __init__(
         self,
@@ -872,6 +870,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
         multimodal_config: MultiModalConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -990,3 +989,91 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+
+
+class QWenLLM(QWenBaseModel):
+    packed_modules_mapping = {
+        "c_attn": ["c_attn"],
+        "gate_up_proj": [
+            "w2",
+            "w1",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "c_attn",
+        "gate_up_proj",
+        "c_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+
+class QWenVL(QWenBaseModel):
+    packed_modules_mapping = {
+        "c_attn": ["c_attn"],
+        "gate_up_proj": [
+            "w2",
+            "w1",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "c_attn",
+        "gate_up_proj",
+        "c_proj",
+        # visual module
+        "out_proj",
+        "in_proj",
+        "c_fc",
+        # resampler
+        "kv_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="transformer.h",
+            connector="transformer.visual.attn_pool",
+            tower_model="transformer.visual.transformer")
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
+class QWenLMHeadModel(QWenBaseModel):
+    """
+    QWenLMHeadModel is not only applicable to LLM but also to VL, which is not
+    conducive to the current integration logic of LoRA in vLLM. Therefore, it
+    is necessary to separate them.
+    """
+    # Ensure that the LoRA support check passes when the class is not
+    # initialized, but set all these attributes to empty.
+    packed_modules_mapping = {}
+    supported_lora_modules = []
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __new__(
+        cls,
+        config: PretrainedConfig,
+        multimodal_config: MultiModalConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ):
+        # Initialize VL
+        if hasattr(config, "visual"):
+            return QWenVL(config, multimodal_config, cache_config,
+                          quant_config, lora_config)
+        # Initialize LLM
+        else:
+            return QWenLLM(config, multimodal_config, cache_config,
+                           quant_config, lora_config)
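The packed_modules_mapping entries record that vLLM fuses Qwen's MLP projections w2 and w1 into a single gate_up_proj (a MergedColumnParallelLinear), so LoRA adapters trained against the original sub-module names have to be regrouped onto the fused module. The sketch below illustrates that regrouping idea only; it is not vLLM's loading code, and the checkpoint keys and helper are hypothetical.

```python
# Illustration only (not vLLM's loader): regroup LoRA tensors keyed by the
# original sub-module names under the fused module that vLLM instantiates.
packed_modules_mapping = {
    "c_attn": ["c_attn"],
    "gate_up_proj": ["w2", "w1"],
}

# Hypothetical LoRA-A checkpoint keys, named after the HF Qwen modules.
lora_a = {
    "transformer.h.0.mlp.w1": "lora_A(w1)",
    "transformer.h.0.mlp.w2": "lora_A(w2)",
    "transformer.h.0.attn.c_attn": "lora_A(c_attn)",
}


def group_by_packed_module(weights: dict) -> dict:
    """Place each tensor at its slot inside the packed (fused) module."""
    grouped: dict = {}
    for name, tensor in weights.items():
        parent, leaf = name.rsplit(".", 1)
        for packed, subs in packed_modules_mapping.items():
            if leaf in subs:
                slots = grouped.setdefault(f"{parent}.{packed}",
                                           [None] * len(subs))
                slots[subs.index(leaf)] = tensor
    return grouped


print(group_by_packed_module(lora_a))
# {'transformer.h.0.mlp.gate_up_proj': ['lora_A(w2)', 'lora_A(w1)'],
#  'transformer.h.0.attn.c_attn': ['lora_A(c_attn)']}
```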
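With the dispatch in QWenLMHeadModel.__new__, the same entrypoint serves both the text-only and the VL checkpoints, and LoRA can now be enabled for either. A usage sketch against vLLM's offline API follows; the model choice, adapter name, and adapter path are placeholders.

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# QWenLMHeadModel.__new__ picks QWenVL when the HF config has a "visual"
# section and QWenLLM otherwise; LoRA can be enabled in both cases.
llm = LLM(
    model="Qwen/Qwen-7B-Chat",   # or a Qwen-VL checkpoint
    trust_remote_code=True,      # Qwen checkpoints ship remote code
    enable_lora=True,
)

outputs = llm.generate(
    ["Give me a short introduction to large language models."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # Placeholder adapter: (name, unique int id, local path to LoRA weights).
    lora_request=LoRARequest("qwen-lora", 1, "/path/to/qwen_lora_adapter"),
)
print(outputs[0].outputs[0].text)
```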