[Model][LoRA] LoRA support added for Qwen (#9622)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
commit 7a4df5f200
parent c5d7fb9ddc
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -578,10 +578,10 @@ class LoRAModelManager(AdapterModelManager):
         be filtered out.
         """
         if self.supports_mm:
-            prefix = module_name.split(".")[0]
             module_mapping: MultiModelKeys = self.model.get_mm_mapping()
-            return (prefix in module_mapping.connector
-                    or prefix in module_mapping.tower_model)
+            prefix_lst = module_mapping.connector + module_mapping.tower_model
+            return any(
+                [module_name.startswith(prefix) for prefix in prefix_lst])
         return False

     def _register_packed_modules(self, module_full_name: str) -> None:
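Note: the old check matched only the first dotted component of the module name (module_name.split(".")[0]), so a multi-level prefix such as "transformer.visual.attn_pool" could never match. The new check concatenates the connector and tower-model prefix lists and uses startswith. A minimal standalone sketch of that behaviour, with illustrative function and variable names (not vLLM internals):

# Standalone sketch of the new prefix filter (illustrative names only).
from typing import List

def is_filtered_for_lora(module_name: str, connector: List[str],
                         tower_model: List[str]) -> bool:
    # A module is excluded from LoRA if its dotted name starts with any
    # connector or vision-tower prefix reported by get_mm_mapping().
    prefix_lst = connector + tower_model
    return any(module_name.startswith(prefix) for prefix in prefix_lst)

connector = ["transformer.visual.attn_pool"]
tower_model = ["transformer.visual.transformer"]

# Nested prefixes now match, which the old first-component check could not do.
assert is_filtered_for_lora(
    "transformer.visual.attn_pool.attn.in_proj", connector, tower_model)
assert not is_filtered_for_lora(
    "transformer.h.0.attn.c_attn", connector, tower_model)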
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -20,7 +20,7 @@ from torchvision.transforms import InterpolationMode
 from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig, MultiModalConfig
+from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                          token_inputs)
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -39,6 +40,7 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -46,7 +48,7 @@ from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.utils import is_list_of

-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (flatten_bn, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
@@ -122,8 +124,8 @@ class VisualAttention(nn.Module):
         # Strided linear layer.
         assert self._qkv_same_embed_dim, \
                 'Visual Attention implementation only supports self-attention'
-        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
-        self.out_proj = nn.Linear(embed_dim, embed_dim)
+        self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)
+        self.out_proj = ReplicatedLinear(embed_dim, embed_dim)
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

     def forward(
@@ -133,7 +135,7 @@ class VisualAttention(nn.Module):
     ) -> torch.Tensor:
         # query/key/value: [sq, b, h]
         sq, b, _ = x.size()
-        mixed_x_layer = self.in_proj(x)
+        mixed_x_layer, _ = self.in_proj(x)

         # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
         new_tensor_shape = mixed_x_layer.size()[:-1] + \
@@ -182,7 +184,7 @@ class VisualAttention(nn.Module):
             (self.hidden_size_per_partition,)
         context_layer = context_layer.view(*new_context_layer_shape)

-        output = self.out_proj(context_layer)
+        output, _ = self.out_proj(context_layer)

         return output

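Note: vLLM's parallel linear layers, including ReplicatedLinear, return an (output, output_bias) tuple from forward(), unlike torch.nn.Linear, which is why the call sites above now unpack a tuple. Moving these projections onto ReplicatedLinear also lets the LoRA wrappers target in_proj and out_proj, which appear in QWenVL.supported_lora_modules further down. A toy stand-in illustrating the call convention (ReplicatedLinearLike is hypothetical, not vLLM code):

import torch
import torch.nn as nn

class ReplicatedLinearLike(nn.Module):
    """Toy stand-in mimicking the (output, bias) return convention."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor):
        # vLLM linear layers return the output together with an optional bias.
        return self.linear(x), None

proj = ReplicatedLinearLike(8, 24)
out, _ = proj(torch.randn(4, 8))  # note the tuple unpacking, as in the diff
assert out.shape == (4, 24)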
@@ -860,11 +862,7 @@ def dummy_data_for_qwen(
     return seq_data, mm_data


-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
-class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
+class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):

     def __init__(
         self,
@@ -872,6 +870,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
         multimodal_config: MultiModalConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -990,3 +989,91 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+
+
+class QWenLLM(QWenBaseModel):
+    packed_modules_mapping = {
+        "c_attn": ["c_attn"],
+        "gate_up_proj": [
+            "w2",
+            "w1",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "c_attn",
+        "gate_up_proj",
+        "c_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+
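Note: packed_modules_mapping describes which checkpoint sub-modules are fused into a single vLLM module so that LoRA weights trained against the original layout can be re-packed: Qwen's two MLP input projections w2 and w1 are merged into gate_up_proj, while c_attn is already a single fused QKV projection. A conceptual sketch of resolving the source names (the helper below is illustrative, not vLLM's API):

from typing import Dict, List

# Mapping copied from the QWenLLM / QWenVL classes in this diff.
packed_modules_mapping: Dict[str, List[str]] = {
    "c_attn": ["c_attn"],
    "gate_up_proj": ["w2", "w1"],
}

def packed_sources(vllm_module: str) -> List[str]:
    """Checkpoint sub-module names that feed one (possibly fused) vLLM module."""
    return packed_modules_mapping.get(vllm_module, [vllm_module])

assert packed_sources("gate_up_proj") == ["w2", "w1"]
assert packed_sources("c_proj") == ["c_proj"]  # unpacked modules map to themselves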
+class QWenVL(QWenBaseModel):
+    packed_modules_mapping = {
+        "c_attn": ["c_attn"],
+        "gate_up_proj": [
+            "w2",
+            "w1",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "c_attn",
+        "gate_up_proj",
+        "c_proj",
+        # visual module
+        "out_proj",
+        "in_proj",
+        "c_fc",
+        # resampler
+        "kv_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="transformer.h",
+            connector="transformer.visual.attn_pool",
+            tower_model="transformer.visual.transformer")
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
+class QWenLMHeadModel(QWenBaseModel):
+    """
+    QWenLMHeadModel is not only applicable to LLM but also to VL, which is not
+    conducive to the current integration logic of LoRA in vLLM. Therefore, it
+    is necessary to separate them.
+    """
+    # Ensure that the LoRA support check passes when the class is not
+    # initialized, but set all these attributes to empty.
+    packed_modules_mapping = {}
+    supported_lora_modules = []
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __new__(
+        cls,
+        config: PretrainedConfig,
+        multimodal_config: MultiModalConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ):
+        # Initialize VL
+        if hasattr(config, "visual"):
+            return QWenVL(config, multimodal_config, cache_config,
+                          quant_config, lora_config)
+        # Initialize LLM
+        else:
+            return QWenLLM(config, multimodal_config, cache_config,
+                           quant_config, lora_config)
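Note: QWenLMHeadModel.__new__ dispatches on the presence of a "visual" section in the HF config, returning QWenVL for Qwen-VL checkpoints and QWenLLM otherwise, so the registered QWenLMHeadModel entry point keeps working for both variants. A usage sketch exercising the new LoRA support through vLLM's public API (the model choice, adapter name, and adapter path are placeholders, not part of this commit):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Qwen requires trust_remote_code; enable_lora turns on vLLM's LoRA machinery.
llm = LLM(model="Qwen/Qwen-7B-Chat", trust_remote_code=True, enable_lora=True)

outputs = llm.generate(
    ["Give me a short introduction to large language models."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # (adapter name, integer id, local path) -- the adapter path is a placeholder.
    lora_request=LoRARequest("qwen_lora", 1, "/path/to/qwen_lora_adapter"),
)
print(outputs[0].outputs[0].text)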