[Model][LoRA]LoRA support added for Qwen (#9622)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent c5d7fb9ddc
commit 7a4df5f200
@@ -578,10 +578,10 @@ class LoRAModelManager(AdapterModelManager):
         be filtered out.
         """
         if self.supports_mm:
-            prefix = module_name.split(".")[0]
             module_mapping: MultiModelKeys = self.model.get_mm_mapping()
-            return (prefix in module_mapping.connector
-                    or prefix in module_mapping.tower_model)
+            prefix_lst = module_mapping.connector + module_mapping.tower_model
+            return any(
+                [module_name.startswith(prefix) for prefix in prefix_lst])
        return False

    def _register_packed_modules(self, module_full_name: str) -> None:
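For orientation, here is a small sketch (not part of the commit) of how the new check behaves: it matches a module name against full dotted prefixes instead of only its first component. The module names below are hypothetical; the prefixes mirror the QWenVL.get_mm_mapping() values added later in this commit.

# Illustrative sketch only; hypothetical module names.
prefix_lst = [
    "transformer.visual.attn_pool",    # connector
    "transformer.visual.transformer",  # tower_model
]

def is_mm_module(module_name: str) -> bool:
    # Same idea as the replacement code above: prefix match on the full
    # dotted path, so language-model layers are not filtered out.
    return any(module_name.startswith(prefix) for prefix in prefix_lst)

assert is_mm_module("transformer.visual.transformer.resblocks.0.attn.in_proj")
assert not is_mm_module("transformer.h.0.attn.c_attn")  # language-model layer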
@@ -20,7 +20,7 @@ from torchvision.transforms import InterpolationMode
 from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig, MultiModalConfig
+from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                          token_inputs)
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -39,6 +40,7 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -46,7 +48,7 @@ from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.utils import is_list_of

-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (flatten_bn, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)

@@ -122,8 +124,8 @@ class VisualAttention(nn.Module):
         # Strided linear layer.
         assert self._qkv_same_embed_dim, \
             'Visual Attention implementation only supports self-attention'
-        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
-        self.out_proj = nn.Linear(embed_dim, embed_dim)
+        self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)
+        self.out_proj = ReplicatedLinear(embed_dim, embed_dim)
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

     def forward(
@@ -133,7 +135,7 @@ class VisualAttention(nn.Module):
     ) -> torch.Tensor:
         # query/key/value: [sq, b, h]
         sq, b, _ = x.size()
-        mixed_x_layer = self.in_proj(x)
+        mixed_x_layer, _ = self.in_proj(x)

         # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
         new_tensor_shape = mixed_x_layer.size()[:-1] + \
@@ -182,7 +184,7 @@ class VisualAttention(nn.Module):
             (self.hidden_size_per_partition,)
         context_layer = context_layer.view(*new_context_layer_shape)

-        output = self.out_proj(context_layer)
+        output, _ = self.out_proj(context_layer)

         return output

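The tuple unpacking in the two hunks above follows from switching in_proj and out_proj to vLLM's ReplicatedLinear, whose forward returns an (output, output_bias) pair rather than the bare tensor that nn.Linear returns. A minimal sketch, assuming ReplicatedLinear can be constructed standalone outside a running engine:

import torch

from vllm.model_executor.layers.linear import ReplicatedLinear

embed_dim = 16
in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)

x = torch.randn(2, embed_dim)
# vLLM linear layers return (output, output_bias); here output_bias is None
# because the bias is applied to the output by default.
out, out_bias = in_proj(x)
assert out.shape == (2, 3 * embed_dim)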
@@ -860,11 +862,7 @@ def dummy_data_for_qwen(
     return seq_data, mm_data


-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
-class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
+class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):

     def __init__(
         self,
@@ -872,6 +870,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
         multimodal_config: MultiModalConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -990,3 +989,91 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+
+
+class QWenLLM(QWenBaseModel):
+    packed_modules_mapping = {
+        "c_attn": ["c_attn"],
+        "gate_up_proj": [
+            "w2",
+            "w1",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "c_attn",
+        "gate_up_proj",
+        "c_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+
+class QWenVL(QWenBaseModel):
+    packed_modules_mapping = {
+        "c_attn": ["c_attn"],
+        "gate_up_proj": [
+            "w2",
+            "w1",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "c_attn",
+        "gate_up_proj",
+        "c_proj",
+        # visual module
+        "out_proj",
+        "in_proj",
+        "c_fc",
+        # resampler
+        "kv_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="transformer.h",
+            connector="transformer.visual.attn_pool",
+            tower_model="transformer.visual.transformer")
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
+class QWenLMHeadModel(QWenBaseModel):
+    """
+    QWenLMHeadModel is not only applicable to LLM but also to VL, which is not
+    conducive to the current integration logic of LoRA in vLLM. Therefore, it
+    is necessary to separate them.
+    """
+    # Ensure that the LoRA support check passes when the class is not
+    # initialized, but set all these attributes to empty.
+    packed_modules_mapping = {}
+    supported_lora_modules = []
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __new__(
+        cls,
+        config: PretrainedConfig,
+        multimodal_config: MultiModalConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ):
+        # Initialize VL
+        if hasattr(config, "visual"):
+            return QWenVL(config, multimodal_config, cache_config,
+                          quant_config, lora_config)
+        # Initialize LLM
+        else:
+            return QWenLLM(config, multimodal_config, cache_config,
+                           quant_config, lora_config)
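With the LLM/VL split in place, LoRA can be enabled for Qwen through the usual vLLM entry points. A minimal usage sketch (not part of the commit); the adapter path and request values are placeholders:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Qwen requires trust_remote_code; enable_lora activates the LoRA code path.
llm = LLM(model="Qwen/Qwen-7B-Chat", enable_lora=True, trust_remote_code=True)

outputs = llm.generate(
    ["Give a one-sentence summary of LoRA."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # (adapter name, unique integer id, local path to the adapter weights)
    lora_request=LoRARequest("qwen-lora", 1, "/path/to/qwen-lora-adapter"),
)
print(outputs[0].outputs[0].text)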