[Model][LoRA]LoRA support added for Qwen (#9622)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2024-10-29 12:14:07 +08:00 committed by GitHub
parent c5d7fb9ddc
commit 7a4df5f200
2 changed files with 101 additions and 14 deletions


@@ -578,10 +578,10 @@ class LoRAModelManager(AdapterModelManager):
         be filtered out.
         """
         if self.supports_mm:
-            prefix = module_name.split(".")[0]
             module_mapping: MultiModelKeys = self.model.get_mm_mapping()
-            return (prefix in module_mapping.connector
-                    or prefix in module_mapping.tower_model)
+            prefix_lst = module_mapping.connector + module_mapping.tower_model
+            return any(
+                [module_name.startswith(prefix) for prefix in prefix_lst])
         return False
 
     def _register_packed_modules(self, module_full_name: str) -> None:


@@ -20,7 +20,7 @@ from torchvision.transforms import InterpolationMode
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig, MultiModalConfig
+from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                          token_inputs)
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -39,6 +40,7 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -46,7 +48,7 @@ from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.utils import is_list_of
 
-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (flatten_bn, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
@@ -122,8 +124,8 @@ class VisualAttention(nn.Module):
         # Strided linear layer.
         assert self._qkv_same_embed_dim, \
             'Visual Attention implementation only supports self-attention'
-        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
-        self.out_proj = nn.Linear(embed_dim, embed_dim)
+        self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)
+        self.out_proj = ReplicatedLinear(embed_dim, embed_dim)
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
 
     def forward(
@@ -133,7 +135,7 @@ class VisualAttention(nn.Module):
     ) -> torch.Tensor:
         # query/key/value: [sq, b, h]
         sq, b, _ = x.size()
-        mixed_x_layer = self.in_proj(x)
+        mixed_x_layer, _ = self.in_proj(x)
 
         # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
         new_tensor_shape = mixed_x_layer.size()[:-1] + \
@@ -182,7 +184,7 @@ class VisualAttention(nn.Module):
             (self.hidden_size_per_partition,)
         context_layer = context_layer.view(*new_context_layer_shape)
 
-        output = self.out_proj(context_layer)
+        output, _ = self.out_proj(context_layer)
 
         return output
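
Unlike torch.nn.Linear, vLLM's linear layers (ReplicatedLinear included) return a tuple of (output, output_bias) from forward(), which is why the call sites above gain the ", _" unpacking; moving the visual projections onto vLLM's own linear classes is also what lets the LoRA layers wrap them. A small stand-in sketch of the interface difference, not the real class:

import torch
import torch.nn as nn


class TupleReturningLinear(nn.Module):
    """Hypothetical stand-in mimicking the (output, output_bias) convention."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor):
        # vLLM linear layers return the output plus an optional bias term.
        return self.linear(x), None


embed_dim = 32
in_proj = TupleReturningLinear(embed_dim, 3 * embed_dim)
x = torch.randn(4, 2, embed_dim)   # [sq, b, h], as in VisualAttention.forward
mixed_x_layer, _ = in_proj(x)      # matches the updated call site
assert mixed_x_layer.shape == (4, 2, 3 * embed_dim)
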
@@ -860,11 +862,7 @@ def dummy_data_for_qwen(
     return seq_data, mm_data
 
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
-class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
+class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
 
     def __init__(
         self,
@@ -872,6 +870,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
         multimodal_config: MultiModalConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -990,3 +989,91 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+
+
+class QWenLLM(QWenBaseModel):
+    packed_modules_mapping = {
+        "c_attn": ["c_attn"],
+        "gate_up_proj": [
+            "w2",
+            "w1",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "c_attn",
+        "gate_up_proj",
+        "c_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+
+class QWenVL(QWenBaseModel):
+    packed_modules_mapping = {
+        "c_attn": ["c_attn"],
+        "gate_up_proj": [
+            "w2",
+            "w1",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "c_attn",
+        "gate_up_proj",
+        "c_proj",
+        # visual module
+        "out_proj",
+        "in_proj",
+        "c_fc",
+        # resampler
+        "kv_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="transformer.h",
+            connector="transformer.visual.attn_pool",
+            tower_model="transformer.visual.transformer")
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
+class QWenLMHeadModel(QWenBaseModel):
+    """
+    QWenLMHeadModel is not only applicable to LLM but also to VL, which is not
+    conducive to the current integration logic of LoRA in vLLM. Therefore, it
+    is necessary to separate them.
+    """
+    # Ensure that the LoRA support check passes when the class is not
+    # initialized, but set all these attributes to empty.
+    packed_modules_mapping = {}
+    supported_lora_modules = []
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __new__(
+        cls,
+        config: PretrainedConfig,
+        multimodal_config: MultiModalConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ):
+        # Initialize VL
+        if hasattr(config, "visual"):
+            return QWenVL(config, multimodal_config, cache_config,
+                          quant_config, lora_config)
+        # Initialize LLM
+        else:
+            return QWenLLM(config, multimodal_config, cache_config,
+                           quant_config, lora_config)
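
For the VL variant, the prefixes returned by get_mm_mapping() ("transformer.visual.attn_pool" and "transformer.visual.transformer") are exactly what the updated LoRAModelManager check at the top of this commit matches against, and __new__ dispatches to QWenVL or QWenLLM depending on whether the HF config carries a "visual" section. A hedged end-to-end sketch of serving a Qwen LoRA adapter after this change; the checkpoint name and adapter path are placeholders, and enable_lora / LoRARequest are vLLM's existing LoRA entry points rather than something introduced here:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholders: swap in the checkpoint and adapter you actually have.
llm = LLM(
    model="Qwen/Qwen-7B-Chat",   # text-only checkpoint -> QWenLLM; a Qwen-VL checkpoint -> QWenVL
    trust_remote_code=True,      # Qwen checkpoints ship remote code
    enable_lora=True,
    max_lora_rank=64,
)

outputs = llm.generate(
    ["Give a one-sentence summary of LoRA."],
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("qwen-lora", 1, "/path/to/qwen_lora_adapter"),
)
print(outputs[0].outputs[0].text)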