diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py
index 4aa23321..293b9fdd 100644
--- a/examples/offline_inference/audio_language.py
+++ b/examples/offline_inference/audio_language.py
@@ -77,6 +77,7 @@ def run_phi4mm(questions: str, audio_count: int):
         enable_lora=True,
         max_lora_rank=320,
         lora_extra_vocab_size=0,
+        limit_mm_per_prompt={"audio": audio_count},
     )
     lora_request = LoRARequest("speech", 1, speech_lora_path)
     # To maintain code compatibility in this script, we add LoRA here.
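Note (not part of the diff): `limit_mm_per_prompt` raises the engine's per-prompt multimodal budget, which otherwise defaults to one item per modality, so the example's prompts that reference several audio clips are accepted. A minimal sketch of the updated construction; the model path is a placeholder, and only the keyword arguments visible in the hunk above come from the script:

from vllm import LLM

audio_count = 2  # e.g. two clips referenced by one prompt
llm = LLM(
    model="microsoft/Phi-4-multimodal-instruct",  # placeholder model path
    trust_remote_code=True,
    enable_lora=True,
    max_lora_rank=320,
    lora_extra_vocab_size=0,
    # Allow up to `audio_count` audio items in a single prompt; without this
    # the engine rejects prompts carrying more than one audio input.
    limit_mm_per_prompt={"audio": audio_count},
)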
diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py
index 6f5ea5af..89abfc59 100644
--- a/vllm/model_executor/models/phi4mm.py
+++ b/vllm/model_executor/models/phi4mm.py
@@ -15,7 +15,7 @@ from transformers import PretrainedConfig
 from transformers.utils import logging
 
 from vllm.config import VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext)
 from vllm.inputs.data import TokenInputs, token_inputs
@@ -34,7 +34,7 @@ from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 
 from .interfaces import SupportsLoRA, SupportsMultiModal
 from .phi4mm_audio import AudioEmbedding
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 from .vision_siglip_navit import get_siglip_vision_model
 
 # <|endoftext10|> (see vocab.json in hf model)
@@ -352,12 +352,6 @@ class Phi4MMImageEncoder(nn.Module):
         # n_embed or hidden_size
         hidden_size = config.n_embd if hasattr(
             config, 'n_embd') else config.hidden_size
-        if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'):
-            embd_drop = config.embd_pdrop if hasattr(
-                config, 'embd_pdrop') else config.embed_pdrop
-            self.drop = nn.Dropout(embd_drop)
-        else:
-            self.drop = None
 
         # layer_idx to output the img features
         if isinstance(config.img_processor, dict):
@@ -1431,6 +1425,20 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         ],
     }
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={
+            "base_layer.": "",
+        },
+        orig_to_new_prefix={
+            "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
+            "embed_tokens_extend.audio_projection_for_vision.",
+            "model.embed_tokens_extend.audio_embed.audio_projection.speech.":
+            "embed_tokens_extend.audio_projection.",
+            "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
+            "model.embed_tokens_extend.image_embed.": "vision_encoder.",
+        },
+    )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -1445,8 +1453,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         self.lora_config = lora_config
 
         # Tensor/Pipeline parallel not supported for now.
-        assert get_tensor_model_parallel_world_size(
-        ) == 1, "tensor parallel is not supported"
         assert get_pp_group(
         ).world_size == 1, "pipeline parallel is not supported"
@@ -1686,44 +1692,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         )
         return merged_embeds
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> None:
-        weights = {name: weight for name, weight in weights}
-        adjusted_weights = {}
-
-        for name, weight in weights.items():
-            # NOTE vision-speech tasks use a separate projection layer
-            audio_proj_4v = \
-                "model.embed_tokens_extend.audio_embed.audio_projection.vision"
-            if name.startswith(audio_proj_4v):
-                name = name.replace(
-                    audio_proj_4v,
-                    "embed_tokens_extend.audio_projection_for_vision")
-
-            name = (name.replace(
-                "model.embed_tokens_extend.audio_embed."\
-                "audio_projection.speech.",
-                "embed_tokens_extend.audio_projection.",
-            ).replace(
-                "model.embed_tokens_extend.audio_embed.",
-                "embed_tokens_extend.",
-            ).replace("model.embed_tokens_extend.image_embed.",
-                      "vision_encoder."))
-            # NOTE: this is deal with LoRA injection, where `base_layer`
-            # remains as the original layer in the model
-            if name.endswith(".base_layer.weight"):
-                name = name.replace(".base_layer.weight", ".weight")
-            adjusted_weights[name] = weight
-
-        missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
-                                                             strict=False)
-        logger.debug("*** missing keys:")
-        for key in missing_keys:
-            logger.debug(key)
-        logger.debug("**** unexpected keys:")
-        for key in unexpected_keys:
-            logger.debug(key)
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1796,6 +1764,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> None:
+        weights = ((name, data) for name, data in weights
+                   if "lora" not in name)
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
     def get_mm_mapping(self) -> MultiModelKeys:
         """
         Get the module prefix in multimodal models
@@ -1804,4 +1779,4 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
             language_model="model.",
             connector=["audio_projection_for_vision", "audio_projection"],
             tower_model=["vision_encoder", "embed_tokens_extend"],
-        )
\ No newline at end of file
+        )
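Note (not part of the diff): the class-level `WeightsMapper` plus `AutoWeightsLoader` replaces the hand-rolled string surgery of the deleted `load_weights`; LoRA tensors are filtered out, presumably because the adapters are loaded separately via `LoRARequest`. A standalone, pure-Python sketch of the same renaming rules, useful for checking where a checkpoint key should land; the example key is hypothetical and the rule ordering here is an assumption, not a statement about `WeightsMapper` internals:

# Pure-Python mirror of hf_to_vllm_mapper's rules: rewrite the first matching
# prefix, then strip the "base_layer." left over from LoRA injection.
_PREFIX_RULES = [
    ("model.embed_tokens_extend.audio_embed.audio_projection.vision.",
     "embed_tokens_extend.audio_projection_for_vision."),
    ("model.embed_tokens_extend.audio_embed.audio_projection.speech.",
     "embed_tokens_extend.audio_projection."),
    ("model.embed_tokens_extend.audio_embed.", "embed_tokens_extend."),
    ("model.embed_tokens_extend.image_embed.", "vision_encoder."),
]


def remap(name: str) -> str:
    for old, new in _PREFIX_RULES:
        if name.startswith(old):
            name = new + name[len(old):]
            break
    return name.replace("base_layer.", "")


# Hypothetical checkpoint key, shown only to illustrate the mapping.
assert remap("model.embed_tokens_extend.image_embed.img_projection.0"
             ".base_layer.weight") == "vision_encoder.img_projection.0.weight"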
diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py
index f9d4881c..db90848f 100644
--- a/vllm/model_executor/models/phi4mm_audio.py
+++ b/vllm/model_executor/models/phi4mm_audio.py
@@ -6,69 +6,26 @@
 #!/usr/bin/env python3
 import abc
 import math
-from functools import partial
-from typing import Callable, Dict, List, Literal, Optional, Union
+from typing import List, Literal, Optional
 
 import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    CheckpointImpl, CheckpointWrapper, checkpoint_wrapper, offload_wrapper)
+    CheckpointWrapper)
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     FullyShardedDataParallel)
-from torch.utils.checkpoint import checkpoint
 from transformers import PretrainedConfig
 
 from vllm.model_executor.models.phi4mm_utils import (
     AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer,
-    MultiHeadedAttention, NemoConvSubsampling, T5RelativeAttentionLogitBias,
-    adaptive_enc_mask, attn_checkpointing, embedding_checkpoint_wrapper,
-    get_offset, repeat, unfold_tensor, validate_checkpointing_config)
+    MultiHeadedAttention, MultiSequential, NemoConvSubsampling,
+    T5RelativeAttentionLogitBias, adaptive_enc_mask, get_offset, unfold_tensor)
 
 _AUDIO_PLACEHOLDER_TOKEN_ID = 200011  # <|endoftext11|>
 
 
-def encoder_checkpoint_wrapper(
-    activation_checkpointing: Union[str, Dict],
-    layer_cls: type,
-    idx: int = 0,
-) -> Callable:
-    """return encoder activation checkpoint wrapper"""
-    validate_checkpointing_config(activation_checkpointing)
-
-    if isinstance(activation_checkpointing, str):
-        if activation_checkpointing:
-            if activation_checkpointing == "offload":
-                return offload_wrapper
-            return partial(checkpoint_wrapper)
-        return lambda x: x
-
-    if isinstance(activation_checkpointing, dict):
-        target_layer_cls = activation_checkpointing.get(
-            "module", "transformer")
-        if target_layer_cls.lower() == "transformer":
-            target_layer_cls = (
-                "EncoderLayer",
-                "ConformerEncoderLayer",
-            )
-        elif target_layer_cls.lower() == "attention":
-            target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention")
-        checkpointing_interval = activation_checkpointing.get("interval", 1)
-        offloading = activation_checkpointing.get("offload", False)
-        impl = (CheckpointImpl.REENTRANT if activation_checkpointing.get(
-            "reentrant", True) else CheckpointImpl.NO_REENTRANT)
-
-        if (idx % checkpointing_interval == 0
-                and layer_cls.__name__ in target_layer_cls):
-            if offloading:
-                return offload_wrapper
-            return partial(checkpoint_wrapper, checkpoint_impl=impl)
-        return lambda x: x
-
-    raise ValueError("Invalid activation_checkpointing config")
-
-
 class ConformerEncoderLayer(nn.Module):
     """ConformerEncoder Layer module.
     for more details see conformer paper:
@@ -208,10 +165,7 @@ class ConformerEncoderLayer(nn.Module):
             bias_in_glu=bias_in_glu,
         )
 
-        self.self_attn = encoder_checkpoint_wrapper(
-            activation_checkpointing,
-            MultiHeadedAttention,
-        )(MultiHeadedAttention(
+        self.self_attn = MultiHeadedAttention(
             n_head,
             d_model,
             dropout_rate,
@@ -221,7 +175,7 @@ class ConformerEncoderLayer(nn.Module):
             use_pt_scaled_dot_product_attention=
             use_pt_scaled_dot_product_attention,
             group_size=attn_group_sizes,
-        ))
+        )
         self.conv = ConvModule(
             d_model,
             ext_pw_out_channel,
@@ -441,26 +395,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
         else:
             raise NotImplementedError
 
-    def post_init(self, init_model_config):
-
-        pretrained_speech_encoder_path = init_model_config.get(
-            "pretrained_speech_encoder_path", None)
-        if pretrained_speech_encoder_path:
-            model_state = torch.load(pretrained_speech_encoder_path,
-                                     map_location="cpu")
-            encoder_state_dict = {}
-            for k, v in model_state.items():
-                if "encoder." in k:
-                    tmp_k = k.replace("encoder.", "")
-                    encoder_state_dict[tmp_k] = v
-
-            if hasattr(self, "encoder_embedding"):
-                del self.encoder_embedding
-            self.load_state_dict(encoder_state_dict)
-
-        if not hasattr(self, "encoder_embedding"):
-            self.encoder_embedding = MeanVarianceNormLayer(
-                self.encoder_embedding_config["input_size"])
+        self.encoder_embedding = MeanVarianceNormLayer(
+            self.encoder_embedding_config["input_size"])
 
     def compute_lens_change(self, feature_lens):
         """feature_lens: int
@@ -558,14 +494,6 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
         # Create mask matrix for streaming
         # S stores start index. if chunksize is 18, s is [0,18,36,....]
         chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff)
-        # avoid randomness when run evaluation or decoding
-        if self.training and np.random.rand() > 0.5:
-            # Either first or last chunk is not complete.
-            # If only the last one is not complete, EOS is not effective
-            chunk_start_idx = seq_len - chunk_start_idx
-            chunk_start_idx = chunk_start_idx[::-1]
-            chunk_start_idx = chunk_start_idx[:-1]
-            chunk_start_idx = np.insert(chunk_start_idx, 0, 0)
         enc_streaming_mask = (adaptive_enc_mask(
             seq_len,
             chunk_start_idx,
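Note (not part of the diff): with the training-only branch gone, the streaming mask always uses the forward chunk layout. A tiny sketch with illustrative numbers; `seq_len` and `chunk_size` here are stand-ins for the real `seq_len` and `chunk_size_train_eff`:

import numpy as np

seq_len, chunk_size = 100, 18  # illustrative values
chunk_start_idx = np.arange(0, seq_len, chunk_size)
print(chunk_start_idx)  # [ 0 18 36 54 72 90] -- matches "s is [0,18,36,....]"
# The deleted branch flipped these boundaries half the time during training so
# that the first rather than the last chunk was incomplete;
# adaptive_enc_mask(seq_len, chunk_start_idx, ...) then builds the mask.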
@@ -883,23 +811,17 @@ class ConformerEncoder(TransformerEncoderBase):
         self.num_blocks = num_blocks
         self.num_lang = num_lang
         self.kernel_size = kernel_size
-        self.embed = embedding_checkpoint_wrapper(activation_checkpointing)(
-            self.embed)
         self.replication_pad_for_subsample_embedding: bool = (
             replication_pad_for_subsample_embedding)
         assert (self.num_heads % attention_group_size == 0
                 ), "attention_group_size must divide n_head"
         self.num_heads_k = self.num_heads // attention_group_size
 
-        self.encoders = repeat(
-            num_blocks,
-            lambda i: encoder_checkpoint_wrapper(activation_checkpointing,
-                                                 ConformerEncoderLayer, i)
-            (ConformerEncoderLayer(
+        self.encoders = MultiSequential(*[
+            ConformerEncoderLayer(
                 d_model=attention_dim,
                 ext_pw_out_channel=ext_pw_out_channel,
-                depthwise_seperable_out_channel=
-                depthwise_seperable_out_channel,
+                depthwise_seperable_out_channel=depthwise_seperable_out_channel,
                 depthwise_multiplier=depthwise_multiplier,
                 n_head=attention_heads,
                 d_ffn=linear_units,
@@ -916,14 +838,13 @@ class ConformerEncoder(TransformerEncoderBase):
                 bias_in_glu=bias_in_glu,
                 linear_glu_in_convm=linear_glu_in_convm,
                 attention_glu_type=attention_glu_type,
-                activation_checkpointing=attn_checkpointing(
-                    activation_checkpointing, i),
+                activation_checkpointing=activation_checkpointing,
                 export=export,
                 use_pt_scaled_dot_product_attention=
                 use_pt_scaled_dot_product_attention,
                 attn_group_sizes=attention_group_size,
-            )),
-        )
+            ) for _ in range(num_blocks)
+        ])
         self.extra_layer_output_idx = extra_layer_output_idx
         self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs
         # Make a zeros scalar we can use in get_initial_state to determine
@@ -1041,9 +962,6 @@ class ConformerEncoder(TransformerEncoderBase):
 
         return input_tensor, masks  # , layer_emb
 
-    def gradient_checkpointing_enable(self):
-        pass
-
 
 class WindowQformer(nn.Module):
     """Window-level Qformer"""
@@ -1077,13 +995,6 @@ class WindowQformer(nn.Module):
         self.after_norm = (nn.LayerNorm(attention_dim, eps=1e-12)
                            if normalize_before else None)
         self.window_size = window_size
-        self.gradient_checkpointing_enable = False
-
-    def enable_gradient_checkpointing(self):
-        self.gradient_checkpointing_enable = True
-
-    def disable_gradient_checkpointing(self):
-        self.gradient_checkpointing_enable = False
 
     def forward(self, audio_embed, mask, embed_len=None):
         """forward decoder"""
@@ -1111,20 +1022,10 @@ class WindowQformer(nn.Module):
         # NT' x 1 x D
         q = self.queries.expand(bsz * slen, -1, -1)
         for layer in self.decoders:
-            if self.gradient_checkpointing_enable and self.training:
-                q = checkpoint(
-                    layer.__call__,
-                    q,
-                    embed_chunk,
-                    None,
-                    mask,
-                    use_reentrant=True,
-                )
-            else:
-                q = layer(tgt=q,
-                          memory=embed_chunk,
-                          tgt_mask=None,
-                          memory_mask=mask)
+            q = layer(tgt=q,
+                      memory=embed_chunk,
+                      tgt_mask=None,
+                      memory_mask=mask)
 
         if self.after_norm is not None:
             q = self.after_norm(q)
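Note (not part of the diff): the encoder stack is now built directly as `MultiSequential(*[ConformerEncoderLayer(...) for _ in range(num_blocks)])` instead of `repeat()` plus per-layer checkpoint wrappers. A self-contained sketch of the multi-input/multi-output Sequential pattern, with a toy layer standing in for `ConformerEncoderLayer`:

import torch
from torch import nn


class MultiSequential(nn.Sequential):
    """Threads a tuple of values (e.g. features plus mask) through each layer."""

    def forward(self, *args):
        for module in self:
            args = module(*args)
        return args


class ToyLayer(nn.Module):
    """Stand-in for ConformerEncoderLayer: takes (x, mask), returns (x, mask)."""

    def forward(self, x, mask):
        return x + 1, mask


encoders = MultiSequential(*[ToyLayer() for _ in range(4)])
x, mask = encoders(torch.zeros(2, 8), torch.ones(2, 8, dtype=torch.bool))
assert torch.equal(x, torch.full((2, 8), 4.0))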
@@ -1147,13 +1048,6 @@ class AudioEmbedding(nn.Module):
         hidden_size = (config.n_embd
                        if hasattr(config, "n_embd") else config.hidden_size)
 
-        if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"):
-            embd_drop = (config.embd_pdrop if hasattr(config, "embd_pdrop")
-                         else config.embed_pdrop)
-            self.drop = nn.Dropout(embd_drop)
-        else:
-            self.drop = None
-
         # self.wte = nn.Embedding(config.vocab_size, hidden_size)
 
         audio_dim_out = (
@@ -1167,12 +1061,6 @@ class AudioEmbedding(nn.Module):
             assert encoder_config is not None
             self.encoder = ConformerEncoder(**encoder_config)
 
-            # fake initialization, create encoder_embedding layer only so that
-            # in decoding, all parameters can be loaded in
-            # from_pretrained_function in training, we do post init after
-            # from_pretrained function to make sure the correct initialization
-            self.encoder.post_init({})
-
             audio_dim_out = encoder_config["attention_dim"]
             n_mels = encoder_config["input_size"]
         else:
@@ -1221,14 +1109,6 @@ class AudioEmbedding(nn.Module):
         else:
             self.conv_ds = None
 
-        enable_gradient_checkpointing = kwargs.get(
-            "enable_gradient_checkpointing", False)
-        if enable_gradient_checkpointing:
-            self.encoder.gradient_checkpointing_enable()
-
-        if self.qformer:
-            self.qformer.enable_gradient_checkpointing()
-
         projection_cls = kwargs.get("projection_cls", "linear")
         if projection_cls == "linear":
             self.audio_projection = nn.Linear(audio_dim_out, hidden_size)
@@ -1388,16 +1268,4 @@ class AudioEmbedding(nn.Module):
                         hidden_states.dtype).to(hidden_states.device))
                 idx += cnt
-        else:
-            if self.training:
-                # hidden_states[:, 0:img_set_tensor.shape[0]] =
-                # hidden_states[:, 0:img_set_tensor.shape[0]] +
-                # 0 * img_set_tensor.to(hidden_states.dtype)
-                # .to(hidden_states.device)
-                hidden_states[:, 0:1] = hidden_states[:, 0:1] + \
-                    0 * audio_set_tensor[:, 0:1].to(hidden_states.dtype)\
-                    .to(hidden_states.device)
-
-        if self.drop is not None:
-            hidden_states = self.drop(hidden_states)
 
         return hidden_states
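Note (not part of the diff): what remains of `AudioEmbedding` is the inference path: encoder features of size `audio_dim_out` are projected to the LM hidden size and written into the positions of the audio placeholder token; the dropout and training-only branches are gone. A rough, standalone sketch of that idea with illustrative sizes; the real forward loops over per-sample embeddings rather than doing one scatter:

import torch
from torch import nn

_AUDIO_PLACEHOLDER_TOKEN_ID = 200011  # <|endoftext11|>, as in phi4mm_audio.py
audio_dim_out, hidden_size = 1024, 3072  # illustrative sizes

# projection_cls="linear" branch shown in the hunk above
audio_projection = nn.Linear(audio_dim_out, hidden_size)

input_ids = torch.tensor([[5, _AUDIO_PLACEHOLDER_TOKEN_ID,
                           _AUDIO_PLACEHOLDER_TOKEN_ID, 7]])
hidden_states = torch.zeros(1, input_ids.shape[1], hidden_size)
audio_features = torch.randn(2, audio_dim_out)  # one row per placeholder slot

with torch.no_grad():  # illustration only, no training path
    positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero(
        as_tuple=True)
    hidden_states[positions] = audio_projection(audio_features)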
diff --git a/vllm/model_executor/models/phi4mm_utils.py b/vllm/model_executor/models/phi4mm_utils.py
index 16b62c60..ca00207a 100644
--- a/vllm/model_executor/models/phi4mm_utils.py
+++ b/vllm/model_executor/models/phi4mm_utils.py
@@ -5,14 +5,11 @@
 # but implemented by the Phi-Speech team
 #!/usr/bin/env python3
 import math
-from functools import partial
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
-from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    CheckpointImpl, checkpoint_wrapper, offload_wrapper)
 
 
 class Block(nn.Module):
@@ -873,10 +870,8 @@ class MeanVarianceNormLayer(nn.Module):
     def __init__(self, input_size):
         super().__init__()
         self.input_size = input_size
-        self.register_buffer("global_mean", torch.zeros(input_size))
-        self.register_buffer("global_invstd", torch.ones(input_size))
-        self.global_mean: Optional[Tensor]
-        self.global_invstd: Optional[Tensor]
+        self.global_mean = nn.Parameter(torch.zeros(input_size))
+        self.global_invstd = nn.Parameter(torch.ones(input_size))
 
     def forward(self, input_: Tensor) -> Tensor:
         """MeanVarianceNormLayer Forward
@@ -1023,21 +1018,10 @@ class CausalConv2D(nn.Conv2d):
         self,
         x,
     ):
-        if self.training:
-            x = F.pad(
-                x,
-                pad=(
-                    self._left_padding,
-                    self._right_padding,
-                    self._left_padding,
-                    self._right_padding,
-                ),
-            )
-        else:
-            x = F.pad(
-                x,
-                pad=(self._left_padding, self._right_padding, 0, 0),
-            )
+        x = F.pad(
+            x,
+            pad=(self._left_padding, self._right_padding, 0, 0),
+        )
         x = super().forward(x)
         return x
@@ -1840,68 +1824,6 @@ class MultiHeadedAttention(nn.Module):
         return self.linear_out(x)  # (batch, time1, d_model)
 
 
-def validate_checkpointing_config(activation_checkpointing):
-    """validate activation checkpointing configuration"""
-    if isinstance(activation_checkpointing, str):
-        assert activation_checkpointing in (
-            "",
-            "checkpoint",
-            "offload",
-        ), "activation_checkpointing has to be a dict or a str in "\
-            "('', 'checkpoint', 'offload')."
-    elif isinstance(activation_checkpointing, dict):
-        assert activation_checkpointing.get("module", "transformer") in (
-            "transformer",
-            "attention",
-        ), "module in activation_checkpointing has to be in "\
-            "('transformer', 'attention')."
-    else:
-        raise ValueError("activation_checkpointing has to be a str"\
-            " or dict.")
-
-
-def embedding_checkpoint_wrapper(
-    activation_checkpointing: Union[str, Dict], ) -> Callable:
-    """return encoder embedding activation checkpoint wrapper"""
-    validate_checkpointing_config(activation_checkpointing)
-
-    if isinstance(activation_checkpointing, str):
-        if activation_checkpointing:
-            if activation_checkpointing == "offload":
-                return offload_wrapper
-            return partial(checkpoint_wrapper)
-        return lambda x: x
-
-    if isinstance(activation_checkpointing, dict):
-        enabled = activation_checkpointing.get("embed", False)
-        if enabled:
-            offloading = activation_checkpointing.get("offload", False)
-            if offloading:
-                return offload_wrapper
-            impl = (CheckpointImpl.REENTRANT if activation_checkpointing.get(
-                "reentrant", False) else CheckpointImpl.NO_REENTRANT)
-            return partial(checkpoint_wrapper, checkpoint_impl=impl)
-        return lambda x: x
-    raise ValueError("Invalid activation_checkpointing config")
-
-
-def attn_checkpointing(activation_checkpointing: Union[str, Dict],
-                       i) -> Union[str, Dict]:
-    """return activation checkpointing config for attention layer"""
-    if isinstance(activation_checkpointing, str):
-        return ""
-
-    if isinstance(activation_checkpointing, dict):
-        target_layer_cls = activation_checkpointing.get(
-            "module", "transformer")
-        checkpointing_interval = activation_checkpointing.get("interval", 1)
-        if target_layer_cls == "attention" and i % checkpointing_interval == 0:
-            return activation_checkpointing
-        return ""
-
-    raise ValueError("Invalid activation_checkpointing config")
-
-
 class MultiSequential(torch.nn.Sequential):
     """Multi-input multi-output torch.nn.Sequential"""
@@ -1913,17 +1835,6 @@ class MultiSequential(torch.nn.Sequential):
         return args
 
 
-def repeat(repeat_num, module_gen_fn):
-    """repeat module N times
-
-    :param int repeat_num: repeat time
-    :param function module_gen_fn: function to generate module
-    :return: repeated modules
-    :rtype: MultiSequential
-    """
-    return MultiSequential(*[module_gen_fn(i) for i in range(repeat_num)])
-
-
 def get_offset(input_layer: str, time_reduction: int):
     """Get an offset.
     We will use the offset for determining #frames of a subsampled feature.
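Note (not part of the diff): turning `global_mean`/`global_invstd` into `nn.Parameter` makes the precomputed normalization statistics show up in `named_parameters()`, so they load through the same path as every other weight (and hence through `AutoWeightsLoader`). A small sketch, assuming the usual `(x - mean) * invstd` normalization in `forward`:

import torch
from torch import Tensor, nn


class MeanVarianceNormLayer(nn.Module):
    """Global mean/variance normalization with loadable statistics."""

    def __init__(self, input_size: int):
        super().__init__()
        self.global_mean = nn.Parameter(torch.zeros(input_size))
        self.global_invstd = nn.Parameter(torch.ones(input_size))

    def forward(self, input_: Tensor) -> Tensor:
        # Assumed normalization, matching the layer's name.
        return (input_ - self.global_mean) * self.global_invstd


layer = MeanVarianceNormLayer(80)  # e.g. 80 mel bins
params = dict(layer.named_parameters())
assert "global_mean" in params and "global_invstd" in params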