[VLM] Add TP support for Phi-4-MM (#14453)

Signed-off-by: Isotr0py <2037008807@qq.com>
Authored by Isotr0py on 2025-03-08 21:57:14 +08:00, committed by GitHub
parent cb8bdfade2
commit 03fe18ae0f
4 changed files with 50 additions and 295 deletions


@@ -77,6 +77,7 @@ def run_phi4mm(questions: str, audio_count: int):
enable_lora=True,
max_lora_rank=320,
lora_extra_vocab_size=0,
limit_mm_per_prompt={"audio": audio_count},
)
lora_request = LoRARequest("speech", 1, speech_lora_path)
# To maintain code compatibility in this script, we add LoRA here.
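With the tensor-parallel assert removed (see the Phi4MMForCausalLM hunks below), the same example can be launched across multiple GPUs. A minimal sketch, not part of this commit, assuming the Hugging Face id microsoft/Phi-4-multimodal-instruct and a 2-GPU host; all other arguments mirror the example above:

from vllm import LLM, SamplingParams

llm = LLM(
    model="microsoft/Phi-4-multimodal-instruct",  # assumed model id
    trust_remote_code=True,
    tensor_parallel_size=2,  # previously rejected by an assert for this model
    enable_lora=True,
    max_lora_rank=320,
    lora_extra_vocab_size=0,
    limit_mm_per_prompt={"audio": 1},
)
# Prompts and LoRARequest usage are unchanged from the example above.
outputs = llm.generate("Hello, my name is", SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)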


@@ -15,7 +15,7 @@ from transformers import PretrainedConfig
from transformers.utils import logging
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.inputs.data import TokenInputs, token_inputs
@@ -34,7 +34,7 @@ from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .interfaces import SupportsLoRA, SupportsMultiModal
from .phi4mm_audio import AudioEmbedding
from .utils import maybe_prefix
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
from .vision_siglip_navit import get_siglip_vision_model
# <|endoftext10|> (see vocab.json in hf model)
@@ -352,12 +352,6 @@ class Phi4MMImageEncoder(nn.Module):
# n_embed or hidden_size
hidden_size = config.n_embd if hasattr(
config, 'n_embd') else config.hidden_size
if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'):
embd_drop = config.embd_pdrop if hasattr(
config, 'embd_pdrop') else config.embed_pdrop
self.drop = nn.Dropout(embd_drop)
else:
self.drop = None
# layer_idx to output the img features
if isinstance(config.img_processor, dict):
@@ -1431,6 +1425,20 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
],
}
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"base_layer.": "",
},
orig_to_new_prefix={
"model.embed_tokens_extend.audio_embed.audio_projection.vision.":
"embed_tokens_extend.audio_projection_for_vision.",
"model.embed_tokens_extend.audio_embed.audio_projection.speech.":
"embed_tokens_extend.audio_projection.",
"model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
"model.embed_tokens_extend.image_embed.": "vision_encoder.",
},
)
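The hf_to_vllm_mapper table replaces the string surgery in the old load_weights further down. A rough standalone sketch of the renaming it encodes (the helper and the sample parameter name are hypothetical, for illustration only; vLLM's WeightsMapper applies equivalent substring/prefix rules):

_PREFIX_RULES = {
    "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
        "embed_tokens_extend.audio_projection_for_vision.",
    "model.embed_tokens_extend.audio_embed.audio_projection.speech.":
        "embed_tokens_extend.audio_projection.",
    "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
    "model.embed_tokens_extend.image_embed.": "vision_encoder.",
}

def _remap(name: str) -> str:
    # Substring rule: drop LoRA's "base_layer." wrapper wherever it appears.
    name = name.replace("base_layer.", "")
    # Prefix rules: the more specific audio_projection.* entries come first
    # so they win over the generic audio_embed. prefix.
    for old, new in _PREFIX_RULES.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name

# Hypothetical checkpoint name, shown only to illustrate the mapping.
assert _remap(
    "model.embed_tokens_extend.image_embed.img_projection.0.base_layer.weight"
) == "vision_encoder.img_projection.0.weight"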
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -1445,8 +1453,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
self.lora_config = lora_config
# Tensor/Pipeline parallel not supported for now.
assert get_tensor_model_parallel_world_size(
) == 1, "tensor parallel is not supported"
assert get_pp_group(
).world_size == 1, "pipeline parallel is not supported"
@@ -1686,44 +1692,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
)
return merged_embeds
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> None:
weights = {name: weight for name, weight in weights}
adjusted_weights = {}
for name, weight in weights.items():
# NOTE vision-speech tasks use a separate projection layer
audio_proj_4v = \
"model.embed_tokens_extend.audio_embed.audio_projection.vision"
if name.startswith(audio_proj_4v):
name = name.replace(
audio_proj_4v,
"embed_tokens_extend.audio_projection_for_vision")
name = (name.replace(
"model.embed_tokens_extend.audio_embed."\
"audio_projection.speech.",
"embed_tokens_extend.audio_projection.",
).replace(
"model.embed_tokens_extend.audio_embed.",
"embed_tokens_extend.",
).replace("model.embed_tokens_extend.image_embed.",
"vision_encoder."))
# NOTE: this is deal with LoRA injection, where `base_layer`
# remains as the original layer in the model
if name.endswith(".base_layer.weight"):
name = name.replace(".base_layer.weight", ".weight")
adjusted_weights[name] = weight
missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
strict=False)
logger.debug("*** missing keys:")
for key in missing_keys:
logger.debug(key)
logger.debug("**** unexpected keys:")
for key in unexpected_keys:
logger.debug(key)
def forward(
self,
input_ids: torch.Tensor,
@@ -1796,6 +1764,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> None:
weights = ((name, data) for name, data in weights
if "lora" not in name)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
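The hand-rolled load_state_dict(strict=False) path above is replaced by AutoWeightsLoader, which hands each checkpoint tensor to the owning submodule so that tensor-parallel layers can apply their own sharding; LoRA tensors are filtered out because adapters are served through LoRARequest rather than fused into the base weights. A conceptual sketch, under the assumption (not taken from this commit) that a sharded parameter carries a per-parameter weight_loader:

import torch

def row_parallel_weight_loader(param: torch.nn.Parameter,
                               loaded_weight: torch.Tensor,
                               tp_rank: int, tp_size: int) -> None:
    # A row-parallel linear keeps only its slice of the input dimension,
    # so the full checkpoint tensor is chunked along dim=1 per rank.
    shard = loaded_weight.chunk(tp_size, dim=1)[tp_rank]
    param.data.copy_(shard)

# By contrast, the removed load_state_dict() call copies full tensors and
# cannot account for per-rank shards, which is why it blocked TP support.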
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
@@ -1804,4 +1779,4 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
language_model="model.",
connector=["audio_projection_for_vision", "audio_projection"],
tower_model=["vision_encoder", "embed_tokens_extend"],
)
)


@@ -6,69 +6,26 @@
#!/usr/bin/env python3
import abc
import math
from functools import partial
from typing import Callable, Dict, List, Literal, Optional, Union
from typing import List, Literal, Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl, CheckpointWrapper, checkpoint_wrapper, offload_wrapper)
CheckpointWrapper)
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullyShardedDataParallel)
from torch.utils.checkpoint import checkpoint
from transformers import PretrainedConfig
from vllm.model_executor.models.phi4mm_utils import (
AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer,
MultiHeadedAttention, NemoConvSubsampling, T5RelativeAttentionLogitBias,
adaptive_enc_mask, attn_checkpointing, embedding_checkpoint_wrapper,
get_offset, repeat, unfold_tensor, validate_checkpointing_config)
MultiHeadedAttention, MultiSequential, NemoConvSubsampling,
T5RelativeAttentionLogitBias, adaptive_enc_mask, get_offset, unfold_tensor)
_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|>
def encoder_checkpoint_wrapper(
activation_checkpointing: Union[str, Dict],
layer_cls: type,
idx: int = 0,
) -> Callable:
"""return encoder activation checkpoint wrapper"""
validate_checkpointing_config(activation_checkpointing)
if isinstance(activation_checkpointing, str):
if activation_checkpointing:
if activation_checkpointing == "offload":
return offload_wrapper
return partial(checkpoint_wrapper)
return lambda x: x
if isinstance(activation_checkpointing, dict):
target_layer_cls = activation_checkpointing.get(
"module", "transformer")
if target_layer_cls.lower() == "transformer":
target_layer_cls = (
"EncoderLayer",
"ConformerEncoderLayer",
)
elif target_layer_cls.lower() == "attention":
target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention")
checkpointing_interval = activation_checkpointing.get("interval", 1)
offloading = activation_checkpointing.get("offload", False)
impl = (CheckpointImpl.REENTRANT if activation_checkpointing.get(
"reentrant", True) else CheckpointImpl.NO_REENTRANT)
if (idx % checkpointing_interval == 0
and layer_cls.__name__ in target_layer_cls):
if offloading:
return offload_wrapper
return partial(checkpoint_wrapper, checkpoint_impl=impl)
return lambda x: x
raise ValueError("Invalid activation_checkpointing config")
class ConformerEncoderLayer(nn.Module):
"""ConformerEncoder Layer module.
for more details see conformer paper:
@@ -208,10 +165,7 @@ class ConformerEncoderLayer(nn.Module):
bias_in_glu=bias_in_glu,
)
self.self_attn = encoder_checkpoint_wrapper(
activation_checkpointing,
MultiHeadedAttention,
)(MultiHeadedAttention(
self.self_attn = MultiHeadedAttention(
n_head,
d_model,
dropout_rate,
@@ -221,7 +175,7 @@ class ConformerEncoderLayer(nn.Module):
use_pt_scaled_dot_product_attention=
use_pt_scaled_dot_product_attention,
group_size=attn_group_sizes,
))
)
self.conv = ConvModule(
d_model,
ext_pw_out_channel,
@@ -441,26 +395,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
else:
raise NotImplementedError
def post_init(self, init_model_config):
pretrained_speech_encoder_path = init_model_config.get(
"pretrained_speech_encoder_path", None)
if pretrained_speech_encoder_path:
model_state = torch.load(pretrained_speech_encoder_path,
map_location="cpu")
encoder_state_dict = {}
for k, v in model_state.items():
if "encoder." in k:
tmp_k = k.replace("encoder.", "")
encoder_state_dict[tmp_k] = v
if hasattr(self, "encoder_embedding"):
del self.encoder_embedding
self.load_state_dict(encoder_state_dict)
if not hasattr(self, "encoder_embedding"):
self.encoder_embedding = MeanVarianceNormLayer(
self.encoder_embedding_config["input_size"])
self.encoder_embedding = MeanVarianceNormLayer(
self.encoder_embedding_config["input_size"])
def compute_lens_change(self, feature_lens):
"""feature_lens: int
@@ -558,14 +494,6 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
# Create mask matrix for streaming
# S stores start index. if chunksize is 18, s is [0,18,36,....]
chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff)
# avoid randomness when run evaluation or decoding
if self.training and np.random.rand() > 0.5:
# Either first or last chunk is not complete.
# If only the last one is not complete, EOS is not effective
chunk_start_idx = seq_len - chunk_start_idx
chunk_start_idx = chunk_start_idx[::-1]
chunk_start_idx = chunk_start_idx[:-1]
chunk_start_idx = np.insert(chunk_start_idx, 0, 0)
enc_streaming_mask = (adaptive_enc_mask(
seq_len, chunk_start_idx,
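Only the deterministic chunking survives: the deleted branch randomly flipped chunk boundaries during training, which is irrelevant for inference. A simplified stand-in (assumption: not the real adaptive_enc_mask, which also supports left-context windows) showing the kind of block-diagonal streaming mask these start indices describe:

import numpy as np

seq_len, chunk_size = 8, 3
chunk_start_idx = np.arange(0, seq_len, chunk_size)          # [0, 3, 6]
chunk_id = np.searchsorted(chunk_start_idx,
                           np.arange(seq_len), side="right") - 1
enc_streaming_mask = chunk_id[:, None] == chunk_id[None, :]
# enc_streaming_mask[i, j] is True only when frames i and j share a chunk,
# so attention cannot cross chunk boundaries.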
@@ -883,23 +811,17 @@ class ConformerEncoder(TransformerEncoderBase):
self.num_blocks = num_blocks
self.num_lang = num_lang
self.kernel_size = kernel_size
self.embed = embedding_checkpoint_wrapper(activation_checkpointing)(
self.embed)
self.replication_pad_for_subsample_embedding: bool = (
replication_pad_for_subsample_embedding)
assert (self.num_heads % attention_group_size == 0
), "attention_group_size must divide n_head"
self.num_heads_k = self.num_heads // attention_group_size
self.encoders = repeat(
num_blocks,
lambda i: encoder_checkpoint_wrapper(activation_checkpointing,
ConformerEncoderLayer, i)
(ConformerEncoderLayer(
self.encoders = MultiSequential(*[
ConformerEncoderLayer(
d_model=attention_dim,
ext_pw_out_channel=ext_pw_out_channel,
depthwise_seperable_out_channel=
depthwise_seperable_out_channel,
depthwise_seperable_out_channel=depthwise_seperable_out_channel,
depthwise_multiplier=depthwise_multiplier,
n_head=attention_heads,
d_ffn=linear_units,
@@ -916,14 +838,13 @@ class ConformerEncoder(TransformerEncoderBase):
bias_in_glu=bias_in_glu,
linear_glu_in_convm=linear_glu_in_convm,
attention_glu_type=attention_glu_type,
activation_checkpointing=attn_checkpointing(
activation_checkpointing, i),
activation_checkpointing=activation_checkpointing,
export=export,
use_pt_scaled_dot_product_attention=
use_pt_scaled_dot_product_attention,
attn_group_sizes=attention_group_size,
)),
)
) for _ in range(num_blocks)
])
self.extra_layer_output_idx = extra_layer_output_idx
self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs
# Make a zeros scalar we can use in get_initial_state to determine
@@ -1041,9 +962,6 @@ class ConformerEncoder(TransformerEncoderBase):
return input_tensor, masks # , layer_emb
def gradient_checkpointing_enable(self):
pass
class WindowQformer(nn.Module):
"""Window-level Qformer"""
@@ -1077,13 +995,6 @@ class WindowQformer(nn.Module):
self.after_norm = (nn.LayerNorm(attention_dim, eps=1e-12)
if normalize_before else None)
self.window_size = window_size
self.gradient_checkpointing_enable = False
def enable_gradient_checkpointing(self):
self.gradient_checkpointing_enable = True
def disable_gradient_checkpointing(self):
self.gradient_checkpointing_enable = False
def forward(self, audio_embed, mask, embed_len=None):
"""forward decoder"""
@@ -1111,20 +1022,10 @@ class WindowQformer(nn.Module):
# NT' x 1 x D
q = self.queries.expand(bsz * slen, -1, -1)
for layer in self.decoders:
if self.gradient_checkpointing_enable and self.training:
q = checkpoint(
layer.__call__,
q,
embed_chunk,
None,
mask,
use_reentrant=True,
)
else:
q = layer(tgt=q,
memory=embed_chunk,
tgt_mask=None,
memory_mask=mask)
q = layer(tgt=q,
memory=embed_chunk,
tgt_mask=None,
memory_mask=mask)
if self.after_norm is not None:
q = self.after_norm(q)
@@ -1147,13 +1048,6 @@ class AudioEmbedding(nn.Module):
hidden_size = (config.n_embd
if hasattr(config, "n_embd") else config.hidden_size)
if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"):
embd_drop = (config.embd_pdrop if hasattr(config, "embd_pdrop")
else config.embed_pdrop)
self.drop = nn.Dropout(embd_drop)
else:
self.drop = None
# self.wte = nn.Embedding(config.vocab_size, hidden_size)
audio_dim_out = (
@@ -1167,12 +1061,6 @@ class AudioEmbedding(nn.Module):
assert encoder_config is not None
self.encoder = ConformerEncoder(**encoder_config)
# fake initialization, create encoder_embedding layer only so that
# in decoding, all parameters can be loaded in
# from_pretrained_function in training, we do post init after
# from_pretrained function to make sure the correct initialization
self.encoder.post_init({})
audio_dim_out = encoder_config["attention_dim"]
n_mels = encoder_config["input_size"]
else:
@@ -1221,14 +1109,6 @@ class AudioEmbedding(nn.Module):
else:
self.conv_ds = None
enable_gradient_checkpointing = kwargs.get(
"enable_gradient_checkpointing", False)
if enable_gradient_checkpointing:
self.encoder.gradient_checkpointing_enable()
if self.qformer:
self.qformer.enable_gradient_checkpointing()
projection_cls = kwargs.get("projection_cls", "linear")
if projection_cls == "linear":
self.audio_projection = nn.Linear(audio_dim_out, hidden_size)
@@ -1388,16 +1268,4 @@ class AudioEmbedding(nn.Module):
hidden_states.dtype).to(hidden_states.device))
idx += cnt
else:
if self.training:
# hidden_states[:, 0:img_set_tensor.shape[0]] =
# hidden_states[:, 0:img_set_tensor.shape[0]] +
# 0 * img_set_tensor.to(hidden_states.dtype)
# .to(hidden_states.device)
hidden_states[:, 0:1] = hidden_states[:, 0:1] + \
0 * audio_set_tensor[:, 0:1].to(hidden_states.dtype)\
.to(hidden_states.device)
if self.drop is not None:
hidden_states = self.drop(hidden_states)
return hidden_states


@@ -5,14 +5,11 @@
# but implemented by the Phi-Speech team
#!/usr/bin/env python3
import math
from functools import partial
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl, checkpoint_wrapper, offload_wrapper)
class Block(nn.Module):
@@ -873,10 +870,8 @@ class MeanVarianceNormLayer(nn.Module):
def __init__(self, input_size):
super().__init__()
self.input_size = input_size
self.register_buffer("global_mean", torch.zeros(input_size))
self.register_buffer("global_invstd", torch.ones(input_size))
self.global_mean: Optional[Tensor]
self.global_invstd: Optional[Tensor]
self.global_mean = nn.Parameter(torch.zeros(input_size))
self.global_invstd = nn.Parameter(torch.ones(input_size))
def forward(self, input_: Tensor) -> Tensor:
"""MeanVarianceNormLayer Forward
@@ -1023,21 +1018,10 @@ class CausalConv2D(nn.Conv2d):
self,
x,
):
if self.training:
x = F.pad(
x,
pad=(
self._left_padding,
self._right_padding,
self._left_padding,
self._right_padding,
),
)
else:
x = F.pad(
x,
pad=(self._left_padding, self._right_padding, 0, 0),
)
x = F.pad(
x,
pad=(self._left_padding, self._right_padding, 0, 0),
)
x = super().forward(x)
return x
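Only the inference branch is kept: the last (time) axis is padded asymmetrically before the convolution so that no output position depends on future frames, and the other axis is left untouched. An illustrative sketch (assumed kernel size and padding values, not the model's actual configuration):

import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(1, 1, 8, 16)              # (batch, channel, freq, time)
left_padding, right_padding = 2, 0        # causal: all padding on the left
x = F.pad(x, pad=(left_padding, right_padding, 0, 0))  # pad time axis only
conv = nn.Conv2d(1, 1, kernel_size=(1, 3))
y = conv(x)
print(y.shape)  # torch.Size([1, 1, 8, 16]); time length preserved, no lookahead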
@@ -1840,68 +1824,6 @@ class MultiHeadedAttention(nn.Module):
return self.linear_out(x) # (batch, time1, d_model)
def validate_checkpointing_config(activation_checkpointing):
"""validate activation checkpointing configuration"""
if isinstance(activation_checkpointing, str):
assert activation_checkpointing in (
"",
"checkpoint",
"offload",
), "activation_checkpointing has to be a dict or a str in "\
"('', 'checkpoint', 'offload')."
elif isinstance(activation_checkpointing, dict):
assert activation_checkpointing.get("module", "transformer") in (
"transformer",
"attention",
), "module in activation_checkpointing has to be in "\
"('transformer', 'attention')."
else:
raise ValueError("activation_checkpointing has to be a str"\
" or dict.")
def embedding_checkpoint_wrapper(
activation_checkpointing: Union[str, Dict], ) -> Callable:
"""return encoder embedding activation checkpoint wrapper"""
validate_checkpointing_config(activation_checkpointing)
if isinstance(activation_checkpointing, str):
if activation_checkpointing:
if activation_checkpointing == "offload":
return offload_wrapper
return partial(checkpoint_wrapper)
return lambda x: x
if isinstance(activation_checkpointing, dict):
enabled = activation_checkpointing.get("embed", False)
if enabled:
offloading = activation_checkpointing.get("offload", False)
if offloading:
return offload_wrapper
impl = (CheckpointImpl.REENTRANT if activation_checkpointing.get(
"reentrant", False) else CheckpointImpl.NO_REENTRANT)
return partial(checkpoint_wrapper, checkpoint_impl=impl)
return lambda x: x
raise ValueError("Invalid activation_checkpointing config")
def attn_checkpointing(activation_checkpointing: Union[str, Dict],
i) -> Union[str, Dict]:
"""return activation checkpointing config for attention layer"""
if isinstance(activation_checkpointing, str):
return ""
if isinstance(activation_checkpointing, dict):
target_layer_cls = activation_checkpointing.get(
"module", "transformer")
checkpointing_interval = activation_checkpointing.get("interval", 1)
if target_layer_cls == "attention" and i % checkpointing_interval == 0:
return activation_checkpointing
return ""
raise ValueError("Invalid activation_checkpointing config")
class MultiSequential(torch.nn.Sequential):
"""Multi-input multi-output torch.nn.Sequential"""
@@ -1913,17 +1835,6 @@ class MultiSequential(torch.nn.Sequential):
return args
def repeat(repeat_num, module_gen_fn):
"""repeat module N times
:param int repeat_num: repeat time
:param function module_gen_fn: function to generate module
:return: repeated modules
:rtype: MultiSequential
"""
return MultiSequential(*[module_gen_fn(i) for i in range(repeat_num)])
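With repeat() gone, ConformerEncoder builds its layer stack with a plain list comprehension passed to MultiSequential (see the hunk above). A minimal re-implementation for illustration (assumption: it mirrors the class shown here) of how a multi-input Sequential threads tuple outputs through each layer:

import torch
from torch import nn

class TupleSequential(nn.Sequential):
    """Each module returns a tuple that is splatted into the next module."""

    def forward(self, *args):
        for module in self:
            args = module(*args)
        return args

class AddAndCount(nn.Module):
    def forward(self, x, count):
        return x + 1, count + 1

stack = TupleSequential(*[AddAndCount() for _ in range(3)])
x, n = stack(torch.zeros(2), 0)
print(x, n)  # tensor([3., 3.]) 3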
def get_offset(input_layer: str, time_reduction: int):
"""Get an offset. We will use the offset for determining #frames of a
subsampled feature.