[VLM] Add TP support for Phi-4-MM (#14453)
Signed-off-by: Isotr0py <2037008807@qq.com>
parent cb8bdfade2
commit 03fe18ae0f
@@ -77,6 +77,7 @@ def run_phi4mm(questions: str, audio_count: int):
         enable_lora=True,
         max_lora_rank=320,
         lora_extra_vocab_size=0,
+        limit_mm_per_prompt={"audio": audio_count},
     )
     lora_request = LoRARequest("speech", 1, speech_lora_path)
     # To maintain code compatibility in this script, we add LoRA here.
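Note on the example change above: `limit_mm_per_prompt={"audio": audio_count}` caps how many audio items a single prompt may carry. The snippet below is an illustrative sketch only and is not part of the diff; it shows how such an engine could be constructed with tensor parallelism, which this commit enables for Phi-4-MM. The checkpoint name and argument values are assumptions; only `limit_mm_per_prompt` and the LoRA flags mirror the example above.

```python
# Illustrative sketch, not taken from the example file.
from vllm import LLM

llm = LLM(
    model="microsoft/Phi-4-multimodal-instruct",  # assumed checkpoint id
    trust_remote_code=True,
    enable_lora=True,
    max_lora_rank=320,
    lora_extra_vocab_size=0,
    limit_mm_per_prompt={"audio": 1},  # at most one audio item per prompt
    tensor_parallel_size=2,            # TP, the feature this commit adds
)
```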
@@ -15,7 +15,7 @@ from transformers import PretrainedConfig
 from transformers.utils import logging
 
 from vllm.config import VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext)
 from vllm.inputs.data import TokenInputs, token_inputs
@@ -34,7 +34,7 @@ from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 
 from .interfaces import SupportsLoRA, SupportsMultiModal
 from .phi4mm_audio import AudioEmbedding
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 from .vision_siglip_navit import get_siglip_vision_model
 
 # <|endoftext10|> (see vocab.json in hf model)
@@ -352,12 +352,6 @@ class Phi4MMImageEncoder(nn.Module):
         # n_embed or hidden_size
         hidden_size = config.n_embd if hasattr(
             config, 'n_embd') else config.hidden_size
-        if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'):
-            embd_drop = config.embd_pdrop if hasattr(
-                config, 'embd_pdrop') else config.embed_pdrop
-            self.drop = nn.Dropout(embd_drop)
-        else:
-            self.drop = None
 
         # layer_idx to output the img features
         if isinstance(config.img_processor, dict):
@@ -1431,6 +1425,20 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         ],
     }
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={
+            "base_layer.": "",
+        },
+        orig_to_new_prefix={
+            "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
+            "embed_tokens_extend.audio_projection_for_vision.",
+            "model.embed_tokens_extend.audio_embed.audio_projection.speech.":
+            "embed_tokens_extend.audio_projection.",
+            "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
+            "model.embed_tokens_extend.image_embed.": "vision_encoder.",
+        },
+    )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
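The class-level `hf_to_vllm_mapper` above declares, as data, how HF checkpoint names are rewritten to vLLM module names (the hand-written renaming it replaces is removed further down in this diff). The sketch below is a plain-Python stand-in, not vLLM's `WeightsMapper` implementation; it only mirrors the substitution rules listed above, and the example key is made up.

```python
# Stand-in for the renaming declared by hf_to_vllm_mapper (illustration only).
_PREFIXES = {
    "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
    "embed_tokens_extend.audio_projection_for_vision.",
    "model.embed_tokens_extend.audio_embed.audio_projection.speech.":
    "embed_tokens_extend.audio_projection.",
    "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
    "model.embed_tokens_extend.image_embed.": "vision_encoder.",
}


def rename(hf_key: str) -> str:
    # Strip the LoRA wrapper marker, then rewrite the first matching prefix
    # (the more specific audio_projection.* prefixes are listed first).
    key = hf_key.replace("base_layer.", "")
    for old, new in _PREFIXES.items():
        if key.startswith(old):
            return new + key[len(old):]
    return key


# Hypothetical checkpoint key, shown only to illustrate the mapping.
assert rename("model.embed_tokens_extend.image_embed.img_projection.0."
              "base_layer.weight") == "vision_encoder.img_projection.0.weight"
```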
@@ -1445,8 +1453,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         self.lora_config = lora_config
 
         # Tensor/Pipeline parallel not supported for now.
-        assert get_tensor_model_parallel_world_size(
-        ) == 1, "tensor parallel is not supported"
         assert get_pp_group(
         ).world_size == 1, "pipeline parallel is not supported"
 
@@ -1686,44 +1692,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         )
         return merged_embeds
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> None:
-        weights = {name: weight for name, weight in weights}
-        adjusted_weights = {}
-
-        for name, weight in weights.items():
-            # NOTE vision-speech tasks use a separate projection layer
-            audio_proj_4v = \
-                "model.embed_tokens_extend.audio_embed.audio_projection.vision"
-            if name.startswith(audio_proj_4v):
-                name = name.replace(
-                    audio_proj_4v,
-                    "embed_tokens_extend.audio_projection_for_vision")
-
-            name = (name.replace(
-                "model.embed_tokens_extend.audio_embed."\
-                "audio_projection.speech.",
-                "embed_tokens_extend.audio_projection.",
-            ).replace(
-                "model.embed_tokens_extend.audio_embed.",
-                "embed_tokens_extend.",
-            ).replace("model.embed_tokens_extend.image_embed.",
-                      "vision_encoder."))
-            # NOTE: this is deal with LoRA injection, where `base_layer`
-            # remains as the original layer in the model
-            if name.endswith(".base_layer.weight"):
-                name = name.replace(".base_layer.weight", ".weight")
-            adjusted_weights[name] = weight
-
-        missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
-                                                             strict=False)
-        logger.debug("*** missing keys:")
-        for key in missing_keys:
-            logger.debug(key)
-        logger.debug("**** unexpected keys:")
-        for key in unexpected_keys:
-            logger.debug(key)
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1796,6 +1764,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> None:
+        weights = ((name, data) for name, data in weights
+                   if "lora" not in name)
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
     def get_mm_mapping(self) -> MultiModelKeys:
         """
         Get the module prefix in multimodal models
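With the mapper in place, the new `load_weights` above just filters out LoRA tensors and hands everything to `AutoWeightsLoader`. A hedged sketch of that flow follows; the tensor names are invented for the example, and the `AutoWeightsLoader` step is described in a comment rather than executed.

```python
# Illustration only; the tensor names are made up.
import torch

checkpoint = [
    ("model.embed_tokens_extend.image_embed.img_projection.0.weight",
     torch.zeros(1)),
    ("model.layers.0.self_attn.qkv_proj.lora_A.weight", torch.zeros(1)),
]

# Step 1 (from the diff): drop LoRA tensors; adapters are attached at runtime
# through LoRARequest instead of being baked into the base weights.
base_weights = [(name, w) for name, w in checkpoint if "lora" not in name]
assert len(base_weights) == 1

# Step 2 (from the diff): AutoWeightsLoader(self).load_weights(
#     base_weights, mapper=self.hf_to_vllm_mapper)
# walks the module tree and loads each tensor into the submodule whose
# mapped name matches, replacing the old manual load_state_dict call.
```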
@@ -6,69 +6,26 @@
 #!/usr/bin/env python3
 import abc
 import math
-from functools import partial
-from typing import Callable, Dict, List, Literal, Optional, Union
+from typing import List, Literal, Optional
 
 import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    CheckpointImpl, CheckpointWrapper, checkpoint_wrapper, offload_wrapper)
+    CheckpointWrapper)
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     FullyShardedDataParallel)
-from torch.utils.checkpoint import checkpoint
 from transformers import PretrainedConfig
 
 from vllm.model_executor.models.phi4mm_utils import (
     AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer,
-    MultiHeadedAttention, NemoConvSubsampling, T5RelativeAttentionLogitBias,
-    adaptive_enc_mask, attn_checkpointing, embedding_checkpoint_wrapper,
-    get_offset, repeat, unfold_tensor, validate_checkpointing_config)
+    MultiHeadedAttention, MultiSequential, NemoConvSubsampling,
+    T5RelativeAttentionLogitBias, adaptive_enc_mask, get_offset, unfold_tensor)
 
 _AUDIO_PLACEHOLDER_TOKEN_ID = 200011  # <|endoftext11|>
 
 
-def encoder_checkpoint_wrapper(
-    activation_checkpointing: Union[str, Dict],
-    layer_cls: type,
-    idx: int = 0,
-) -> Callable:
-    """return encoder activation checkpoint wrapper"""
-    validate_checkpointing_config(activation_checkpointing)
-
-    if isinstance(activation_checkpointing, str):
-        if activation_checkpointing:
-            if activation_checkpointing == "offload":
-                return offload_wrapper
-            return partial(checkpoint_wrapper)
-        return lambda x: x
-
-    if isinstance(activation_checkpointing, dict):
-        target_layer_cls = activation_checkpointing.get(
-            "module", "transformer")
-        if target_layer_cls.lower() == "transformer":
-            target_layer_cls = (
-                "EncoderLayer",
-                "ConformerEncoderLayer",
-            )
-        elif target_layer_cls.lower() == "attention":
-            target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention")
-        checkpointing_interval = activation_checkpointing.get("interval", 1)
-        offloading = activation_checkpointing.get("offload", False)
-        impl = (CheckpointImpl.REENTRANT if activation_checkpointing.get(
-            "reentrant", True) else CheckpointImpl.NO_REENTRANT)
-
-        if (idx % checkpointing_interval == 0
-                and layer_cls.__name__ in target_layer_cls):
-            if offloading:
-                return offload_wrapper
-            return partial(checkpoint_wrapper, checkpoint_impl=impl)
-        return lambda x: x
-
-    raise ValueError("Invalid activation_checkpointing config")
-
-
 class ConformerEncoderLayer(nn.Module):
     """ConformerEncoder Layer module.
     for more details see conformer paper:
@@ -208,10 +165,7 @@ class ConformerEncoderLayer(nn.Module):
             bias_in_glu=bias_in_glu,
         )
 
-        self.self_attn = encoder_checkpoint_wrapper(
-            activation_checkpointing,
-            MultiHeadedAttention,
-        )(MultiHeadedAttention(
+        self.self_attn = MultiHeadedAttention(
             n_head,
             d_model,
             dropout_rate,
@@ -221,7 +175,7 @@ class ConformerEncoderLayer(nn.Module):
             use_pt_scaled_dot_product_attention=
             use_pt_scaled_dot_product_attention,
             group_size=attn_group_sizes,
-        ))
+        )
         self.conv = ConvModule(
             d_model,
             ext_pw_out_channel,
@@ -441,24 +395,6 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
         else:
             raise NotImplementedError
 
-    def post_init(self, init_model_config):
-
-        pretrained_speech_encoder_path = init_model_config.get(
-            "pretrained_speech_encoder_path", None)
-        if pretrained_speech_encoder_path:
-            model_state = torch.load(pretrained_speech_encoder_path,
-                                     map_location="cpu")
-            encoder_state_dict = {}
-            for k, v in model_state.items():
-                if "encoder." in k:
-                    tmp_k = k.replace("encoder.", "")
-                    encoder_state_dict[tmp_k] = v
-
-            if hasattr(self, "encoder_embedding"):
-                del self.encoder_embedding
-            self.load_state_dict(encoder_state_dict)
-
-        if not hasattr(self, "encoder_embedding"):
-            self.encoder_embedding = MeanVarianceNormLayer(
-                self.encoder_embedding_config["input_size"])
+        self.encoder_embedding = MeanVarianceNormLayer(
+            self.encoder_embedding_config["input_size"])
 
@@ -558,14 +494,6 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
             # Create mask matrix for streaming
             # S stores start index. if chunksize is 18, s is [0,18,36,....]
             chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff)
-            # avoid randomness when run evaluation or decoding
-            if self.training and np.random.rand() > 0.5:
-                # Either first or last chunk is not complete.
-                # If only the last one is not complete, EOS is not effective
-                chunk_start_idx = seq_len - chunk_start_idx
-                chunk_start_idx = chunk_start_idx[::-1]
-                chunk_start_idx = chunk_start_idx[:-1]
-                chunk_start_idx = np.insert(chunk_start_idx, 0, 0)
 
             enc_streaming_mask = (adaptive_enc_mask(
                 seq_len, chunk_start_idx,
@@ -883,23 +811,17 @@ class ConformerEncoder(TransformerEncoderBase):
         self.num_blocks = num_blocks
         self.num_lang = num_lang
         self.kernel_size = kernel_size
-        self.embed = embedding_checkpoint_wrapper(activation_checkpointing)(
-            self.embed)
         self.replication_pad_for_subsample_embedding: bool = (
             replication_pad_for_subsample_embedding)
         assert (self.num_heads % attention_group_size == 0
                 ), "attention_group_size must divide n_head"
         self.num_heads_k = self.num_heads // attention_group_size
 
-        self.encoders = repeat(
-            num_blocks,
-            lambda i: encoder_checkpoint_wrapper(activation_checkpointing,
-                                                 ConformerEncoderLayer, i)
-            (ConformerEncoderLayer(
+        self.encoders = MultiSequential(*[
+            ConformerEncoderLayer(
                 d_model=attention_dim,
                 ext_pw_out_channel=ext_pw_out_channel,
-                depthwise_seperable_out_channel=
-                depthwise_seperable_out_channel,
+                depthwise_seperable_out_channel=depthwise_seperable_out_channel,
                 depthwise_multiplier=depthwise_multiplier,
                 n_head=attention_heads,
                 d_ffn=linear_units,
@@ -916,14 +838,13 @@ class ConformerEncoder(TransformerEncoderBase):
                 bias_in_glu=bias_in_glu,
                 linear_glu_in_convm=linear_glu_in_convm,
                 attention_glu_type=attention_glu_type,
-                activation_checkpointing=attn_checkpointing(
-                    activation_checkpointing, i),
+                activation_checkpointing=activation_checkpointing,
                 export=export,
                 use_pt_scaled_dot_product_attention=
                 use_pt_scaled_dot_product_attention,
                 attn_group_sizes=attention_group_size,
-            )),
-        )
+            ) for _ in range(num_blocks)
+        ])
         self.extra_layer_output_idx = extra_layer_output_idx
         self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs
         # Make a zeros scalar we can use in get_initial_state to determine
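The two hunks above replace the `repeat()`/checkpoint-wrapper construction with a plain list comprehension over `ConformerEncoderLayer`, held in a `MultiSequential`. The sketch below is illustrative only: it uses a dummy layer and a minimal `MultiSequential` re-implementation that matches the docstring shown later in this diff, just to show that the old factory-based pattern and the new list comprehension build the same kind of stack.

```python
import torch
from torch import nn


class MultiSequential(nn.Sequential):
    """Multi-input multi-output Sequential (sketch of the phi4mm_utils class)."""

    def forward(self, *args):
        for module in self:
            args = module(*args)
        return args


class DummyLayer(nn.Module):
    """Stand-in for ConformerEncoderLayer, only to show the wiring."""

    def forward(self, x, mask):
        return x + 1, mask


# Old style: repeat(num_blocks, lambda i: wrapper(i)(ConformerEncoderLayer(...)))
# New style in this commit: build the layers directly in a list.
num_blocks = 3
encoders = MultiSequential(*[DummyLayer() for _ in range(num_blocks)])

x, mask = encoders(torch.zeros(2, 4), None)
assert torch.equal(x, torch.full((2, 4), 3.0))
```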
@@ -1041,9 +962,6 @@ class ConformerEncoder(TransformerEncoderBase):
 
         return input_tensor, masks  # , layer_emb
 
-    def gradient_checkpointing_enable(self):
-        pass
-
 
 class WindowQformer(nn.Module):
     """Window-level Qformer"""
@@ -1077,13 +995,6 @@ class WindowQformer(nn.Module):
         self.after_norm = (nn.LayerNorm(attention_dim, eps=1e-12)
                            if normalize_before else None)
         self.window_size = window_size
-        self.gradient_checkpointing_enable = False
-
-    def enable_gradient_checkpointing(self):
-        self.gradient_checkpointing_enable = True
-
-    def disable_gradient_checkpointing(self):
-        self.gradient_checkpointing_enable = False
 
     def forward(self, audio_embed, mask, embed_len=None):
         """forward decoder"""
@@ -1111,16 +1022,6 @@ class WindowQformer(nn.Module):
         # NT' x 1 x D
         q = self.queries.expand(bsz * slen, -1, -1)
         for layer in self.decoders:
-            if self.gradient_checkpointing_enable and self.training:
-                q = checkpoint(
-                    layer.__call__,
-                    q,
-                    embed_chunk,
-                    None,
-                    mask,
-                    use_reentrant=True,
-                )
-            else:
-                q = layer(tgt=q,
-                          memory=embed_chunk,
-                          tgt_mask=None,
+            q = layer(tgt=q,
+                      memory=embed_chunk,
+                      tgt_mask=None,
@@ -1147,13 +1048,6 @@ class AudioEmbedding(nn.Module):
         hidden_size = (config.n_embd
                        if hasattr(config, "n_embd") else config.hidden_size)
 
-        if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"):
-            embd_drop = (config.embd_pdrop if hasattr(config, "embd_pdrop")
-                         else config.embed_pdrop)
-            self.drop = nn.Dropout(embd_drop)
-        else:
-            self.drop = None
-
         # self.wte = nn.Embedding(config.vocab_size, hidden_size)
 
         audio_dim_out = (
@@ -1167,12 +1061,6 @@ class AudioEmbedding(nn.Module):
             assert encoder_config is not None
             self.encoder = ConformerEncoder(**encoder_config)
 
-            # fake initialization, create encoder_embedding layer only so that
-            # in decoding, all parameters can be loaded in
-            # from_pretrained_function in training, we do post init after
-            # from_pretrained function to make sure the correct initialization
-            self.encoder.post_init({})
-
             audio_dim_out = encoder_config["attention_dim"]
             n_mels = encoder_config["input_size"]
         else:
@@ -1221,14 +1109,6 @@ class AudioEmbedding(nn.Module):
         else:
             self.conv_ds = None
 
-        enable_gradient_checkpointing = kwargs.get(
-            "enable_gradient_checkpointing", False)
-        if enable_gradient_checkpointing:
-            self.encoder.gradient_checkpointing_enable()
-
-            if self.qformer:
-                self.qformer.enable_gradient_checkpointing()
-
         projection_cls = kwargs.get("projection_cls", "linear")
         if projection_cls == "linear":
             self.audio_projection = nn.Linear(audio_dim_out, hidden_size)
@@ -1388,16 +1268,4 @@ class AudioEmbedding(nn.Module):
                         hidden_states.dtype).to(hidden_states.device))
                 idx += cnt
 
-        else:
-            if self.training:
-                # hidden_states[:, 0:img_set_tensor.shape[0]] =
-                # hidden_states[:, 0:img_set_tensor.shape[0]] +
-                # 0 * img_set_tensor.to(hidden_states.dtype)
-                # .to(hidden_states.device)
-                hidden_states[:, 0:1] = hidden_states[:, 0:1] + \
-                    0 * audio_set_tensor[:, 0:1].to(hidden_states.dtype)\
-                    .to(hidden_states.device)
-
-        if self.drop is not None:
-            hidden_states = self.drop(hidden_states)
         return hidden_states
@@ -5,14 +5,11 @@
 # but implemented by the Phi-Speech team
 #!/usr/bin/env python3
 import math
-from functools import partial
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
-from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    CheckpointImpl, checkpoint_wrapper, offload_wrapper)
 
 
 class Block(nn.Module):
@@ -873,10 +870,8 @@ class MeanVarianceNormLayer(nn.Module):
     def __init__(self, input_size):
         super().__init__()
         self.input_size = input_size
-        self.register_buffer("global_mean", torch.zeros(input_size))
-        self.register_buffer("global_invstd", torch.ones(input_size))
-        self.global_mean: Optional[Tensor]
-        self.global_invstd: Optional[Tensor]
+        self.global_mean = nn.Parameter(torch.zeros(input_size))
+        self.global_invstd = nn.Parameter(torch.ones(input_size))
 
     def forward(self, input_: Tensor) -> Tensor:
         """MeanVarianceNormLayer Forward
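In the hunk above, `global_mean` and `global_invstd` become `nn.Parameter`s instead of registered buffers; the normalization itself is unchanged. The sketch below shows the computation these statistics feed. It is reconstructed from the layer's name and docstring rather than copied from the file, so treat the exact expression as an assumption.

```python
import torch
from torch import nn


class MeanVarianceNormSketch(nn.Module):
    """Hedged re-sketch of MeanVarianceNormLayer: normalize features using
    precomputed global statistics, stored as Parameters as in the new code."""

    def __init__(self, input_size: int):
        super().__init__()
        self.global_mean = nn.Parameter(torch.zeros(input_size))
        self.global_invstd = nn.Parameter(torch.ones(input_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumed expression: subtract the global mean and scale by the
        # global inverse standard deviation.
        return (x - self.global_mean) * self.global_invstd


feats = torch.randn(2, 10, 80)  # (batch, time, n_mels)
layer = MeanVarianceNormSketch(80)
assert layer(feats).shape == feats.shape
```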
@@ -1023,17 +1018,6 @@ class CausalConv2D(nn.Conv2d):
         self,
         x,
     ):
-        if self.training:
-            x = F.pad(
-                x,
-                pad=(
-                    self._left_padding,
-                    self._right_padding,
-                    self._left_padding,
-                    self._right_padding,
-                ),
-            )
-        else:
-            x = F.pad(
-                x,
-                pad=(self._left_padding, self._right_padding, 0, 0),
+        x = F.pad(
+            x,
+            pad=(self._left_padding, self._right_padding, 0, 0),
@@ -1840,68 +1824,6 @@ class MultiHeadedAttention(nn.Module):
         return self.linear_out(x)  # (batch, time1, d_model)
 
 
-def validate_checkpointing_config(activation_checkpointing):
-    """validate activation checkpointing configuration"""
-    if isinstance(activation_checkpointing, str):
-        assert activation_checkpointing in (
-            "",
-            "checkpoint",
-            "offload",
-        ), "activation_checkpointing has to be a dict or a str in "\
-            "('', 'checkpoint', 'offload')."
-    elif isinstance(activation_checkpointing, dict):
-        assert activation_checkpointing.get("module", "transformer") in (
-            "transformer",
-            "attention",
-        ), "module in activation_checkpointing has to be in "\
-            "('transformer', 'attention')."
-    else:
-        raise ValueError("activation_checkpointing has to be a str"\
-            " or dict.")
-
-
-def embedding_checkpoint_wrapper(
-        activation_checkpointing: Union[str, Dict], ) -> Callable:
-    """return encoder embedding activation checkpoint wrapper"""
-    validate_checkpointing_config(activation_checkpointing)
-
-    if isinstance(activation_checkpointing, str):
-        if activation_checkpointing:
-            if activation_checkpointing == "offload":
-                return offload_wrapper
-            return partial(checkpoint_wrapper)
-        return lambda x: x
-
-    if isinstance(activation_checkpointing, dict):
-        enabled = activation_checkpointing.get("embed", False)
-        if enabled:
-            offloading = activation_checkpointing.get("offload", False)
-            if offloading:
-                return offload_wrapper
-            impl = (CheckpointImpl.REENTRANT if activation_checkpointing.get(
-                "reentrant", False) else CheckpointImpl.NO_REENTRANT)
-            return partial(checkpoint_wrapper, checkpoint_impl=impl)
-        return lambda x: x
-    raise ValueError("Invalid activation_checkpointing config")
-
-
-def attn_checkpointing(activation_checkpointing: Union[str, Dict],
-                       i) -> Union[str, Dict]:
-    """return activation checkpointing config for attention layer"""
-    if isinstance(activation_checkpointing, str):
-        return ""
-
-    if isinstance(activation_checkpointing, dict):
-        target_layer_cls = activation_checkpointing.get(
-            "module", "transformer")
-        checkpointing_interval = activation_checkpointing.get("interval", 1)
-        if target_layer_cls == "attention" and i % checkpointing_interval == 0:
-            return activation_checkpointing
-        return ""
-
-    raise ValueError("Invalid activation_checkpointing config")
-
-
 class MultiSequential(torch.nn.Sequential):
     """Multi-input multi-output torch.nn.Sequential"""
 
@@ -1913,17 +1835,6 @@ class MultiSequential(torch.nn.Sequential):
         return args
 
 
-def repeat(repeat_num, module_gen_fn):
-    """repeat module N times
-
-    :param int repeat_num: repeat time
-    :param function module_gen_fn: function to generate module
-    :return: repeated modules
-    :rtype: MultiSequential
-    """
-    return MultiSequential(*[module_gen_fn(i) for i in range(repeat_num)])
-
-
 def get_offset(input_layer: str, time_reduction: int):
     """Get an offset. We will use the offset for determining #frames of a
     subsampled feature.