[VLM] Add TP support for Phi-4-MM (#14453)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
cb8bdfade2
commit
03fe18ae0f
@ -77,6 +77,7 @@ def run_phi4mm(questions: str, audio_count: int):
|
||||
enable_lora=True,
|
||||
max_lora_rank=320,
|
||||
lora_extra_vocab_size=0,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
lora_request = LoRARequest("speech", 1, speech_lora_path)
|
||||
# To maintain code compatibility in this script, we add LoRA here.
|
||||
|
@ -15,7 +15,7 @@ from transformers import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext)
|
||||
from vllm.inputs.data import TokenInputs, token_inputs
|
||||
@ -34,7 +34,7 @@ from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal
|
||||
from .phi4mm_audio import AudioEmbedding
|
||||
from .utils import maybe_prefix
|
||||
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||
from .vision_siglip_navit import get_siglip_vision_model
|
||||
|
||||
# <|endoftext10|> (see vocab.json in hf model)
|
||||
@ -352,12 +352,6 @@ class Phi4MMImageEncoder(nn.Module):
|
||||
# n_embed or hidden_size
|
||||
hidden_size = config.n_embd if hasattr(
|
||||
config, 'n_embd') else config.hidden_size
|
||||
if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'):
|
||||
embd_drop = config.embd_pdrop if hasattr(
|
||||
config, 'embd_pdrop') else config.embed_pdrop
|
||||
self.drop = nn.Dropout(embd_drop)
|
||||
else:
|
||||
self.drop = None
|
||||
|
||||
# layer_idx to output the img features
|
||||
if isinstance(config.img_processor, dict):
|
||||
@ -1431,6 +1425,20 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
],
|
||||
}
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
"base_layer.": "",
|
||||
},
|
||||
orig_to_new_prefix={
|
||||
"model.embed_tokens_extend.audio_embed.audio_projection.vision.":
|
||||
"embed_tokens_extend.audio_projection_for_vision.",
|
||||
"model.embed_tokens_extend.audio_embed.audio_projection.speech.":
|
||||
"embed_tokens_extend.audio_projection.",
|
||||
"model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
|
||||
"model.embed_tokens_extend.image_embed.": "vision_encoder.",
|
||||
},
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
@ -1445,8 +1453,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
self.lora_config = lora_config
|
||||
|
||||
# Tensor/Pipeline parallel not supported for now.
|
||||
assert get_tensor_model_parallel_world_size(
|
||||
) == 1, "tensor parallel is not supported"
|
||||
assert get_pp_group(
|
||||
).world_size == 1, "pipeline parallel is not supported"
|
||||
|
||||
@ -1686,44 +1692,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
)
|
||||
return merged_embeds
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> None:
|
||||
weights = {name: weight for name, weight in weights}
|
||||
adjusted_weights = {}
|
||||
|
||||
for name, weight in weights.items():
|
||||
# NOTE vision-speech tasks use a separate projection layer
|
||||
audio_proj_4v = \
|
||||
"model.embed_tokens_extend.audio_embed.audio_projection.vision"
|
||||
if name.startswith(audio_proj_4v):
|
||||
name = name.replace(
|
||||
audio_proj_4v,
|
||||
"embed_tokens_extend.audio_projection_for_vision")
|
||||
|
||||
name = (name.replace(
|
||||
"model.embed_tokens_extend.audio_embed."\
|
||||
"audio_projection.speech.",
|
||||
"embed_tokens_extend.audio_projection.",
|
||||
).replace(
|
||||
"model.embed_tokens_extend.audio_embed.",
|
||||
"embed_tokens_extend.",
|
||||
).replace("model.embed_tokens_extend.image_embed.",
|
||||
"vision_encoder."))
|
||||
# NOTE: this is deal with LoRA injection, where `base_layer`
|
||||
# remains as the original layer in the model
|
||||
if name.endswith(".base_layer.weight"):
|
||||
name = name.replace(".base_layer.weight", ".weight")
|
||||
adjusted_weights[name] = weight
|
||||
|
||||
missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
|
||||
strict=False)
|
||||
logger.debug("*** missing keys:")
|
||||
for key in missing_keys:
|
||||
logger.debug(key)
|
||||
logger.debug("**** unexpected keys:")
|
||||
for key in unexpected_keys:
|
||||
logger.debug(key)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -1796,6 +1764,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> None:
|
||||
weights = ((name, data) for name, data in weights
|
||||
if "lora" not in name)
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
Get the module prefix in multimodal models
|
||||
@ -1804,4 +1779,4 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
language_model="model.",
|
||||
connector=["audio_projection_for_vision", "audio_projection"],
|
||||
tower_model=["vision_encoder", "embed_tokens_extend"],
|
||||
)
|
||||
)
|
||||
|
@ -6,69 +6,26 @@
|
||||
#!/usr/bin/env python3
|
||||
import abc
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Literal, Optional, Union
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||
CheckpointImpl, CheckpointWrapper, checkpoint_wrapper, offload_wrapper)
|
||||
CheckpointWrapper)
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import (
|
||||
FullyShardedDataParallel)
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.model_executor.models.phi4mm_utils import (
|
||||
AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer,
|
||||
MultiHeadedAttention, NemoConvSubsampling, T5RelativeAttentionLogitBias,
|
||||
adaptive_enc_mask, attn_checkpointing, embedding_checkpoint_wrapper,
|
||||
get_offset, repeat, unfold_tensor, validate_checkpointing_config)
|
||||
MultiHeadedAttention, MultiSequential, NemoConvSubsampling,
|
||||
T5RelativeAttentionLogitBias, adaptive_enc_mask, get_offset, unfold_tensor)
|
||||
|
||||
_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|>
|
||||
|
||||
|
||||
def encoder_checkpoint_wrapper(
|
||||
activation_checkpointing: Union[str, Dict],
|
||||
layer_cls: type,
|
||||
idx: int = 0,
|
||||
) -> Callable:
|
||||
"""return encoder activation checkpoint wrapper"""
|
||||
validate_checkpointing_config(activation_checkpointing)
|
||||
|
||||
if isinstance(activation_checkpointing, str):
|
||||
if activation_checkpointing:
|
||||
if activation_checkpointing == "offload":
|
||||
return offload_wrapper
|
||||
return partial(checkpoint_wrapper)
|
||||
return lambda x: x
|
||||
|
||||
if isinstance(activation_checkpointing, dict):
|
||||
target_layer_cls = activation_checkpointing.get(
|
||||
"module", "transformer")
|
||||
if target_layer_cls.lower() == "transformer":
|
||||
target_layer_cls = (
|
||||
"EncoderLayer",
|
||||
"ConformerEncoderLayer",
|
||||
)
|
||||
elif target_layer_cls.lower() == "attention":
|
||||
target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention")
|
||||
checkpointing_interval = activation_checkpointing.get("interval", 1)
|
||||
offloading = activation_checkpointing.get("offload", False)
|
||||
impl = (CheckpointImpl.REENTRANT if activation_checkpointing.get(
|
||||
"reentrant", True) else CheckpointImpl.NO_REENTRANT)
|
||||
|
||||
if (idx % checkpointing_interval == 0
|
||||
and layer_cls.__name__ in target_layer_cls):
|
||||
if offloading:
|
||||
return offload_wrapper
|
||||
return partial(checkpoint_wrapper, checkpoint_impl=impl)
|
||||
return lambda x: x
|
||||
|
||||
raise ValueError("Invalid activation_checkpointing config")
|
||||
|
||||
|
||||
class ConformerEncoderLayer(nn.Module):
|
||||
"""ConformerEncoder Layer module.
|
||||
for more details see conformer paper:
|
||||
@ -208,10 +165,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
bias_in_glu=bias_in_glu,
|
||||
)
|
||||
|
||||
self.self_attn = encoder_checkpoint_wrapper(
|
||||
activation_checkpointing,
|
||||
MultiHeadedAttention,
|
||||
)(MultiHeadedAttention(
|
||||
self.self_attn = MultiHeadedAttention(
|
||||
n_head,
|
||||
d_model,
|
||||
dropout_rate,
|
||||
@ -221,7 +175,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
use_pt_scaled_dot_product_attention=
|
||||
use_pt_scaled_dot_product_attention,
|
||||
group_size=attn_group_sizes,
|
||||
))
|
||||
)
|
||||
self.conv = ConvModule(
|
||||
d_model,
|
||||
ext_pw_out_channel,
|
||||
@ -441,26 +395,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def post_init(self, init_model_config):
|
||||
|
||||
pretrained_speech_encoder_path = init_model_config.get(
|
||||
"pretrained_speech_encoder_path", None)
|
||||
if pretrained_speech_encoder_path:
|
||||
model_state = torch.load(pretrained_speech_encoder_path,
|
||||
map_location="cpu")
|
||||
encoder_state_dict = {}
|
||||
for k, v in model_state.items():
|
||||
if "encoder." in k:
|
||||
tmp_k = k.replace("encoder.", "")
|
||||
encoder_state_dict[tmp_k] = v
|
||||
|
||||
if hasattr(self, "encoder_embedding"):
|
||||
del self.encoder_embedding
|
||||
self.load_state_dict(encoder_state_dict)
|
||||
|
||||
if not hasattr(self, "encoder_embedding"):
|
||||
self.encoder_embedding = MeanVarianceNormLayer(
|
||||
self.encoder_embedding_config["input_size"])
|
||||
self.encoder_embedding = MeanVarianceNormLayer(
|
||||
self.encoder_embedding_config["input_size"])
|
||||
|
||||
def compute_lens_change(self, feature_lens):
|
||||
"""feature_lens: int
|
||||
@ -558,14 +494,6 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
# Create mask matrix for streaming
|
||||
# S stores start index. if chunksize is 18, s is [0,18,36,....]
|
||||
chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff)
|
||||
# avoid randomness when run evaluation or decoding
|
||||
if self.training and np.random.rand() > 0.5:
|
||||
# Either first or last chunk is not complete.
|
||||
# If only the last one is not complete, EOS is not effective
|
||||
chunk_start_idx = seq_len - chunk_start_idx
|
||||
chunk_start_idx = chunk_start_idx[::-1]
|
||||
chunk_start_idx = chunk_start_idx[:-1]
|
||||
chunk_start_idx = np.insert(chunk_start_idx, 0, 0)
|
||||
|
||||
enc_streaming_mask = (adaptive_enc_mask(
|
||||
seq_len, chunk_start_idx,
|
||||
@ -883,23 +811,17 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
self.num_blocks = num_blocks
|
||||
self.num_lang = num_lang
|
||||
self.kernel_size = kernel_size
|
||||
self.embed = embedding_checkpoint_wrapper(activation_checkpointing)(
|
||||
self.embed)
|
||||
self.replication_pad_for_subsample_embedding: bool = (
|
||||
replication_pad_for_subsample_embedding)
|
||||
assert (self.num_heads % attention_group_size == 0
|
||||
), "attention_group_size must divide n_head"
|
||||
self.num_heads_k = self.num_heads // attention_group_size
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks,
|
||||
lambda i: encoder_checkpoint_wrapper(activation_checkpointing,
|
||||
ConformerEncoderLayer, i)
|
||||
(ConformerEncoderLayer(
|
||||
self.encoders = MultiSequential(*[
|
||||
ConformerEncoderLayer(
|
||||
d_model=attention_dim,
|
||||
ext_pw_out_channel=ext_pw_out_channel,
|
||||
depthwise_seperable_out_channel=
|
||||
depthwise_seperable_out_channel,
|
||||
depthwise_seperable_out_channel=depthwise_seperable_out_channel,
|
||||
depthwise_multiplier=depthwise_multiplier,
|
||||
n_head=attention_heads,
|
||||
d_ffn=linear_units,
|
||||
@ -916,14 +838,13 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
bias_in_glu=bias_in_glu,
|
||||
linear_glu_in_convm=linear_glu_in_convm,
|
||||
attention_glu_type=attention_glu_type,
|
||||
activation_checkpointing=attn_checkpointing(
|
||||
activation_checkpointing, i),
|
||||
activation_checkpointing=activation_checkpointing,
|
||||
export=export,
|
||||
use_pt_scaled_dot_product_attention=
|
||||
use_pt_scaled_dot_product_attention,
|
||||
attn_group_sizes=attention_group_size,
|
||||
)),
|
||||
)
|
||||
) for _ in range(num_blocks)
|
||||
])
|
||||
self.extra_layer_output_idx = extra_layer_output_idx
|
||||
self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs
|
||||
# Make a zeros scalar we can use in get_initial_state to determine
|
||||
@ -1041,9 +962,6 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
|
||||
return input_tensor, masks # , layer_emb
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
pass
|
||||
|
||||
|
||||
class WindowQformer(nn.Module):
|
||||
"""Window-level Qformer"""
|
||||
@ -1077,13 +995,6 @@ class WindowQformer(nn.Module):
|
||||
self.after_norm = (nn.LayerNorm(attention_dim, eps=1e-12)
|
||||
if normalize_before else None)
|
||||
self.window_size = window_size
|
||||
self.gradient_checkpointing_enable = False
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing_enable = True
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing_enable = False
|
||||
|
||||
def forward(self, audio_embed, mask, embed_len=None):
|
||||
"""forward decoder"""
|
||||
@ -1111,20 +1022,10 @@ class WindowQformer(nn.Module):
|
||||
# NT' x 1 x D
|
||||
q = self.queries.expand(bsz * slen, -1, -1)
|
||||
for layer in self.decoders:
|
||||
if self.gradient_checkpointing_enable and self.training:
|
||||
q = checkpoint(
|
||||
layer.__call__,
|
||||
q,
|
||||
embed_chunk,
|
||||
None,
|
||||
mask,
|
||||
use_reentrant=True,
|
||||
)
|
||||
else:
|
||||
q = layer(tgt=q,
|
||||
memory=embed_chunk,
|
||||
tgt_mask=None,
|
||||
memory_mask=mask)
|
||||
q = layer(tgt=q,
|
||||
memory=embed_chunk,
|
||||
tgt_mask=None,
|
||||
memory_mask=mask)
|
||||
|
||||
if self.after_norm is not None:
|
||||
q = self.after_norm(q)
|
||||
@ -1147,13 +1048,6 @@ class AudioEmbedding(nn.Module):
|
||||
hidden_size = (config.n_embd
|
||||
if hasattr(config, "n_embd") else config.hidden_size)
|
||||
|
||||
if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"):
|
||||
embd_drop = (config.embd_pdrop if hasattr(config, "embd_pdrop")
|
||||
else config.embed_pdrop)
|
||||
self.drop = nn.Dropout(embd_drop)
|
||||
else:
|
||||
self.drop = None
|
||||
|
||||
# self.wte = nn.Embedding(config.vocab_size, hidden_size)
|
||||
|
||||
audio_dim_out = (
|
||||
@ -1167,12 +1061,6 @@ class AudioEmbedding(nn.Module):
|
||||
assert encoder_config is not None
|
||||
self.encoder = ConformerEncoder(**encoder_config)
|
||||
|
||||
# fake initialization, create encoder_embedding layer only so that
|
||||
# in decoding, all parameters can be loaded in
|
||||
# from_pretrained_function in training, we do post init after
|
||||
# from_pretrained function to make sure the correct initialization
|
||||
self.encoder.post_init({})
|
||||
|
||||
audio_dim_out = encoder_config["attention_dim"]
|
||||
n_mels = encoder_config["input_size"]
|
||||
else:
|
||||
@ -1221,14 +1109,6 @@ class AudioEmbedding(nn.Module):
|
||||
else:
|
||||
self.conv_ds = None
|
||||
|
||||
enable_gradient_checkpointing = kwargs.get(
|
||||
"enable_gradient_checkpointing", False)
|
||||
if enable_gradient_checkpointing:
|
||||
self.encoder.gradient_checkpointing_enable()
|
||||
|
||||
if self.qformer:
|
||||
self.qformer.enable_gradient_checkpointing()
|
||||
|
||||
projection_cls = kwargs.get("projection_cls", "linear")
|
||||
if projection_cls == "linear":
|
||||
self.audio_projection = nn.Linear(audio_dim_out, hidden_size)
|
||||
@ -1388,16 +1268,4 @@ class AudioEmbedding(nn.Module):
|
||||
hidden_states.dtype).to(hidden_states.device))
|
||||
idx += cnt
|
||||
|
||||
else:
|
||||
if self.training:
|
||||
# hidden_states[:, 0:img_set_tensor.shape[0]] =
|
||||
# hidden_states[:, 0:img_set_tensor.shape[0]] +
|
||||
# 0 * img_set_tensor.to(hidden_states.dtype)
|
||||
# .to(hidden_states.device)
|
||||
hidden_states[:, 0:1] = hidden_states[:, 0:1] + \
|
||||
0 * audio_set_tensor[:, 0:1].to(hidden_states.dtype)\
|
||||
.to(hidden_states.device)
|
||||
|
||||
if self.drop is not None:
|
||||
hidden_states = self.drop(hidden_states)
|
||||
return hidden_states
|
||||
|
@ -5,14 +5,11 @@
|
||||
# but implemented by the Phi-Speech team
|
||||
#!/usr/bin/env python3
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||
CheckpointImpl, checkpoint_wrapper, offload_wrapper)
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
@ -873,10 +870,8 @@ class MeanVarianceNormLayer(nn.Module):
|
||||
def __init__(self, input_size):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.register_buffer("global_mean", torch.zeros(input_size))
|
||||
self.register_buffer("global_invstd", torch.ones(input_size))
|
||||
self.global_mean: Optional[Tensor]
|
||||
self.global_invstd: Optional[Tensor]
|
||||
self.global_mean = nn.Parameter(torch.zeros(input_size))
|
||||
self.global_invstd = nn.Parameter(torch.ones(input_size))
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
"""MeanVarianceNormLayer Forward
|
||||
@ -1023,21 +1018,10 @@ class CausalConv2D(nn.Conv2d):
|
||||
self,
|
||||
x,
|
||||
):
|
||||
if self.training:
|
||||
x = F.pad(
|
||||
x,
|
||||
pad=(
|
||||
self._left_padding,
|
||||
self._right_padding,
|
||||
self._left_padding,
|
||||
self._right_padding,
|
||||
),
|
||||
)
|
||||
else:
|
||||
x = F.pad(
|
||||
x,
|
||||
pad=(self._left_padding, self._right_padding, 0, 0),
|
||||
)
|
||||
x = F.pad(
|
||||
x,
|
||||
pad=(self._left_padding, self._right_padding, 0, 0),
|
||||
)
|
||||
x = super().forward(x)
|
||||
return x
|
||||
|
||||
@ -1840,68 +1824,6 @@ class MultiHeadedAttention(nn.Module):
|
||||
return self.linear_out(x) # (batch, time1, d_model)
|
||||
|
||||
|
||||
def validate_checkpointing_config(activation_checkpointing):
|
||||
"""validate activation checkpointing configuration"""
|
||||
if isinstance(activation_checkpointing, str):
|
||||
assert activation_checkpointing in (
|
||||
"",
|
||||
"checkpoint",
|
||||
"offload",
|
||||
), "activation_checkpointing has to be a dict or a str in "\
|
||||
"('', 'checkpoint', 'offload')."
|
||||
elif isinstance(activation_checkpointing, dict):
|
||||
assert activation_checkpointing.get("module", "transformer") in (
|
||||
"transformer",
|
||||
"attention",
|
||||
), "module in activation_checkpointing has to be in "\
|
||||
"('transformer', 'attention')."
|
||||
else:
|
||||
raise ValueError("activation_checkpointing has to be a str"\
|
||||
" or dict.")
|
||||
|
||||
|
||||
def embedding_checkpoint_wrapper(
|
||||
activation_checkpointing: Union[str, Dict], ) -> Callable:
|
||||
"""return encoder embedding activation checkpoint wrapper"""
|
||||
validate_checkpointing_config(activation_checkpointing)
|
||||
|
||||
if isinstance(activation_checkpointing, str):
|
||||
if activation_checkpointing:
|
||||
if activation_checkpointing == "offload":
|
||||
return offload_wrapper
|
||||
return partial(checkpoint_wrapper)
|
||||
return lambda x: x
|
||||
|
||||
if isinstance(activation_checkpointing, dict):
|
||||
enabled = activation_checkpointing.get("embed", False)
|
||||
if enabled:
|
||||
offloading = activation_checkpointing.get("offload", False)
|
||||
if offloading:
|
||||
return offload_wrapper
|
||||
impl = (CheckpointImpl.REENTRANT if activation_checkpointing.get(
|
||||
"reentrant", False) else CheckpointImpl.NO_REENTRANT)
|
||||
return partial(checkpoint_wrapper, checkpoint_impl=impl)
|
||||
return lambda x: x
|
||||
raise ValueError("Invalid activation_checkpointing config")
|
||||
|
||||
|
||||
def attn_checkpointing(activation_checkpointing: Union[str, Dict],
|
||||
i) -> Union[str, Dict]:
|
||||
"""return activation checkpointing config for attention layer"""
|
||||
if isinstance(activation_checkpointing, str):
|
||||
return ""
|
||||
|
||||
if isinstance(activation_checkpointing, dict):
|
||||
target_layer_cls = activation_checkpointing.get(
|
||||
"module", "transformer")
|
||||
checkpointing_interval = activation_checkpointing.get("interval", 1)
|
||||
if target_layer_cls == "attention" and i % checkpointing_interval == 0:
|
||||
return activation_checkpointing
|
||||
return ""
|
||||
|
||||
raise ValueError("Invalid activation_checkpointing config")
|
||||
|
||||
|
||||
class MultiSequential(torch.nn.Sequential):
|
||||
"""Multi-input multi-output torch.nn.Sequential"""
|
||||
|
||||
@ -1913,17 +1835,6 @@ class MultiSequential(torch.nn.Sequential):
|
||||
return args
|
||||
|
||||
|
||||
def repeat(repeat_num, module_gen_fn):
|
||||
"""repeat module N times
|
||||
|
||||
:param int repeat_num: repeat time
|
||||
:param function module_gen_fn: function to generate module
|
||||
:return: repeated modules
|
||||
:rtype: MultiSequential
|
||||
"""
|
||||
return MultiSequential(*[module_gen_fn(i) for i in range(repeat_num)])
|
||||
|
||||
|
||||
def get_offset(input_layer: str, time_reduction: int):
|
||||
"""Get an offset. We will use the offset for determining #frames of a
|
||||
subsampled feature.
|
||||
|
Loading…
x
Reference in New Issue
Block a user