[VLM] Add TP support for Phi-4-MM (#14453)

Signed-off-by: Isotr0py <2037008807@qq.com>
Author: Isotr0py
Date: 2025-03-08 21:57:14 +08:00
Committed by: GitHub
Parent: cb8bdfade2
Commit: 03fe18ae0f
4 changed files with 50 additions and 295 deletions

View File

@@ -77,6 +77,7 @@ def run_phi4mm(questions: str, audio_count: int):
         enable_lora=True,
         max_lora_rank=320,
         lora_extra_vocab_size=0,
+        limit_mm_per_prompt={"audio": audio_count},
     )
     lora_request = LoRARequest("speech", 1, speech_lora_path)
     # To maintain code compatibility in this script, we add LoRA here.
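Note: a minimal sketch of how the example constructs the engine after this change. The model id, audio count, and LoRA path below are illustrative placeholders rather than values taken from the script.

    from vllm import LLM
    from vllm.lora.request import LoRARequest

    audio_count = 2                                   # illustrative value
    speech_lora_path = "/path/to/speech-lora"         # placeholder path
    llm = LLM(
        model="microsoft/Phi-4-multimodal-instruct",  # assumed checkpoint id
        trust_remote_code=True,
        enable_lora=True,
        max_lora_rank=320,
        lora_extra_vocab_size=0,
        # New in this commit: cap the number of audio items per prompt.
        limit_mm_per_prompt={"audio": audio_count},
    )
    lora_request = LoRARequest("speech", 1, speech_lora_path)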

View File

@@ -15,7 +15,7 @@ from transformers import PretrainedConfig
 from transformers.utils import logging

 from vllm.config import VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext)
 from vllm.inputs.data import TokenInputs, token_inputs
@@ -34,7 +34,7 @@ from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 from .interfaces import SupportsLoRA, SupportsMultiModal
 from .phi4mm_audio import AudioEmbedding
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 from .vision_siglip_navit import get_siglip_vision_model

 # <|endoftext10|> (see vocab.json in hf model)
@@ -352,12 +352,6 @@ class Phi4MMImageEncoder(nn.Module):
         # n_embed or hidden_size
         hidden_size = config.n_embd if hasattr(
             config, 'n_embd') else config.hidden_size
-        if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'):
-            embd_drop = config.embd_pdrop if hasattr(
-                config, 'embd_pdrop') else config.embed_pdrop
-            self.drop = nn.Dropout(embd_drop)
-        else:
-            self.drop = None

         # layer_idx to output the img features
         if isinstance(config.img_processor, dict):
@@ -1431,6 +1425,20 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         ],
     }

+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={
+            "base_layer.": "",
+        },
+        orig_to_new_prefix={
+            "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
+            "embed_tokens_extend.audio_projection_for_vision.",
+            "model.embed_tokens_extend.audio_embed.audio_projection.speech.":
+            "embed_tokens_extend.audio_projection.",
+            "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
+            "model.embed_tokens_extend.image_embed.": "vision_encoder.",
+        },
+    )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -1445,8 +1453,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         self.lora_config = lora_config

-        # Tensor/Pipeline parallel not supported for now.
-        assert get_tensor_model_parallel_world_size(
-        ) == 1, "tensor parallel is not supported"
+        # Pipeline parallel not supported for now.
         assert get_pp_group(
         ).world_size == 1, "pipeline parallel is not supported"
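Note: with the tensor-parallel assertion gone (only the pipeline-parallel check remains), Phi-4-MM can be sharded like other vLLM models. A hedged usage sketch; the model id and TP degree below are illustrative:

    from vllm import LLM

    llm = LLM(
        model="microsoft/Phi-4-multimodal-instruct",  # assumed checkpoint id
        trust_remote_code=True,
        tensor_parallel_size=2,  # previously asserted to be 1
        limit_mm_per_prompt={"audio": 1},
    )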
@@ -1686,44 +1692,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         )
         return merged_embeds

-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> None:
-        weights = {name: weight for name, weight in weights}
-        adjusted_weights = {}
-
-        for name, weight in weights.items():
-            # NOTE vision-speech tasks use a separate projection layer
-            audio_proj_4v = \
-                "model.embed_tokens_extend.audio_embed.audio_projection.vision"
-            if name.startswith(audio_proj_4v):
-                name = name.replace(
-                    audio_proj_4v,
-                    "embed_tokens_extend.audio_projection_for_vision")
-
-            name = (name.replace(
-                "model.embed_tokens_extend.audio_embed."\
-                "audio_projection.speech.",
-                "embed_tokens_extend.audio_projection.",
-            ).replace(
-                "model.embed_tokens_extend.audio_embed.",
-                "embed_tokens_extend.",
-            ).replace("model.embed_tokens_extend.image_embed.",
-                      "vision_encoder."))
-
-            # NOTE: this is deal with LoRA injection, where `base_layer`
-            # remains as the original layer in the model
-            if name.endswith(".base_layer.weight"):
-                name = name.replace(".base_layer.weight", ".weight")
-            adjusted_weights[name] = weight
-
-        missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
-                                                             strict=False)
-        logger.debug("*** missing keys:")
-        for key in missing_keys:
-            logger.debug(key)
-        logger.debug("**** unexpected keys:")
-        for key in unexpected_keys:
-            logger.debug(key)
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1796,6 +1764,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens

+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> None:
+        weights = ((name, data) for name, data in weights
+                   if "lora" not in name)
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
     def get_mm_mapping(self) -> MultiModelKeys:
         """
         Get the module prefix in multimodal models
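Note: the declarative hf_to_vllm_mapper plus AutoWeightsLoader replaces the hand-rolled renaming loop removed above. The snippet below is a rough, self-contained illustration of the renaming it expresses, not the WeightsMapper implementation itself; the checkpoint key is hypothetical.

    def remap(name: str) -> str:
        # Same intent as hf_to_vllm_mapper: drop the LoRA "base_layer." marker,
        # then rewrite the HF prefix to the matching vLLM module prefix.
        name = name.replace("base_layer.", "")
        prefixes = {
            "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
            "embed_tokens_extend.audio_projection_for_vision.",
            "model.embed_tokens_extend.audio_embed.audio_projection.speech.":
            "embed_tokens_extend.audio_projection.",
            "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
            "model.embed_tokens_extend.image_embed.": "vision_encoder.",
        }
        for old, new in prefixes.items():  # longer prefixes are listed first
            if name.startswith(old):
                return new + name[len(old):]
        return name

    # Hypothetical checkpoint key, for illustration only:
    print(remap(
        "model.embed_tokens_extend.image_embed.img_projection.0.base_layer.weight"))
    # -> vision_encoder.img_projection.0.weight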

View File

@@ -6,69 +6,26 @@
 #!/usr/bin/env python3
 import abc
 import math
-from functools import partial
-from typing import Callable, Dict, List, Literal, Optional, Union
+from typing import List, Literal, Optional

 import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    CheckpointImpl, CheckpointWrapper, checkpoint_wrapper, offload_wrapper)
+    CheckpointWrapper)
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     FullyShardedDataParallel)
-from torch.utils.checkpoint import checkpoint
 from transformers import PretrainedConfig

 from vllm.model_executor.models.phi4mm_utils import (
     AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer,
-    MultiHeadedAttention, NemoConvSubsampling, T5RelativeAttentionLogitBias,
-    adaptive_enc_mask, attn_checkpointing, embedding_checkpoint_wrapper,
-    get_offset, repeat, unfold_tensor, validate_checkpointing_config)
+    MultiHeadedAttention, MultiSequential, NemoConvSubsampling,
+    T5RelativeAttentionLogitBias, adaptive_enc_mask, get_offset, unfold_tensor)

 _AUDIO_PLACEHOLDER_TOKEN_ID = 200011  # <|endoftext11|>

-
-def encoder_checkpoint_wrapper(
-    activation_checkpointing: Union[str, Dict],
-    layer_cls: type,
-    idx: int = 0,
-) -> Callable:
-    """return encoder activation checkpoint wrapper"""
-    validate_checkpointing_config(activation_checkpointing)
-
-    if isinstance(activation_checkpointing, str):
-        if activation_checkpointing:
-            if activation_checkpointing == "offload":
-                return offload_wrapper
-            return partial(checkpoint_wrapper)
-        return lambda x: x
-
-    if isinstance(activation_checkpointing, dict):
-        target_layer_cls = activation_checkpointing.get(
-            "module", "transformer")
-        if target_layer_cls.lower() == "transformer":
-            target_layer_cls = (
-                "EncoderLayer",
-                "ConformerEncoderLayer",
-            )
-        elif target_layer_cls.lower() == "attention":
-            target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention")
-        checkpointing_interval = activation_checkpointing.get("interval", 1)
-        offloading = activation_checkpointing.get("offload", False)
-        impl = (CheckpointImpl.REENTRANT if activation_checkpointing.get(
-            "reentrant", True) else CheckpointImpl.NO_REENTRANT)
-        if (idx % checkpointing_interval == 0
-                and layer_cls.__name__ in target_layer_cls):
-            if offloading:
-                return offload_wrapper
-            return partial(checkpoint_wrapper, checkpoint_impl=impl)
-        return lambda x: x
-
-    raise ValueError("Invalid activation_checkpointing config")
-

 class ConformerEncoderLayer(nn.Module):
     """ConformerEncoder Layer module.
     for more details see conformer paper:
@@ -208,10 +165,7 @@ class ConformerEncoderLayer(nn.Module):
             bias_in_glu=bias_in_glu,
         )

-        self.self_attn = encoder_checkpoint_wrapper(
-            activation_checkpointing,
-            MultiHeadedAttention,
-        )(MultiHeadedAttention(
+        self.self_attn = MultiHeadedAttention(
             n_head,
             d_model,
             dropout_rate,
@@ -221,7 +175,7 @@ class ConformerEncoderLayer(nn.Module):
             use_pt_scaled_dot_product_attention=
             use_pt_scaled_dot_product_attention,
             group_size=attn_group_sizes,
-        ))
+        )
         self.conv = ConvModule(
             d_model,
             ext_pw_out_channel,
@@ -441,24 +395,6 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
         else:
             raise NotImplementedError

-    def post_init(self, init_model_config):
-        pretrained_speech_encoder_path = init_model_config.get(
-            "pretrained_speech_encoder_path", None)
-        if pretrained_speech_encoder_path:
-            model_state = torch.load(pretrained_speech_encoder_path,
-                                     map_location="cpu")
-            encoder_state_dict = {}
-            for k, v in model_state.items():
-                if "encoder." in k:
-                    tmp_k = k.replace("encoder.", "")
-                    encoder_state_dict[tmp_k] = v
-
-            if hasattr(self, "encoder_embedding"):
-                del self.encoder_embedding
-            self.load_state_dict(encoder_state_dict)
-
-        if not hasattr(self, "encoder_embedding"):
-            self.encoder_embedding = MeanVarianceNormLayer(
-                self.encoder_embedding_config["input_size"])
+        self.encoder_embedding = MeanVarianceNormLayer(
+            self.encoder_embedding_config["input_size"])
@@ -558,14 +494,6 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
         # Create mask matrix for streaming
         # S stores start index. if chunksize is 18, s is [0,18,36,....]
         chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff)
-        # avoid randomness when run evaluation or decoding
-        if self.training and np.random.rand() > 0.5:
-            # Either first or last chunk is not complete.
-            # If only the last one is not complete, EOS is not effective
-            chunk_start_idx = seq_len - chunk_start_idx
-            chunk_start_idx = chunk_start_idx[::-1]
-            chunk_start_idx = chunk_start_idx[:-1]
-            chunk_start_idx = np.insert(chunk_start_idx, 0, 0)

         enc_streaming_mask = (adaptive_enc_mask(
             seq_len, chunk_start_idx,
@@ -883,23 +811,17 @@ class ConformerEncoder(TransformerEncoderBase):
         self.num_blocks = num_blocks
         self.num_lang = num_lang
         self.kernel_size = kernel_size
-        self.embed = embedding_checkpoint_wrapper(activation_checkpointing)(
-            self.embed)
         self.replication_pad_for_subsample_embedding: bool = (
             replication_pad_for_subsample_embedding)
         assert (self.num_heads % attention_group_size == 0
                 ), "attention_group_size must divide n_head"
         self.num_heads_k = self.num_heads // attention_group_size
-        self.encoders = repeat(
-            num_blocks,
-            lambda i: encoder_checkpoint_wrapper(activation_checkpointing,
-                                                 ConformerEncoderLayer, i)
-            (ConformerEncoderLayer(
+        self.encoders = MultiSequential(*[
+            ConformerEncoderLayer(
                 d_model=attention_dim,
                 ext_pw_out_channel=ext_pw_out_channel,
-                depthwise_seperable_out_channel=
-                depthwise_seperable_out_channel,
+                depthwise_seperable_out_channel=depthwise_seperable_out_channel,
                 depthwise_multiplier=depthwise_multiplier,
                 n_head=attention_heads,
                 d_ffn=linear_units,
@@ -916,14 +838,13 @@ class ConformerEncoder(TransformerEncoderBase):
                 bias_in_glu=bias_in_glu,
                 linear_glu_in_convm=linear_glu_in_convm,
                 attention_glu_type=attention_glu_type,
-                activation_checkpointing=attn_checkpointing(
-                    activation_checkpointing, i),
+                activation_checkpointing=activation_checkpointing,
                 export=export,
                 use_pt_scaled_dot_product_attention=
                 use_pt_scaled_dot_product_attention,
                 attn_group_sizes=attention_group_size,
-            )),
-        )
+            ) for _ in range(num_blocks)
+        ])
         self.extra_layer_output_idx = extra_layer_output_idx
         self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs
         # Make a zeros scalar we can use in get_initial_state to determine
@@ -1041,9 +962,6 @@ class ConformerEncoder(TransformerEncoderBase):
         return input_tensor, masks  # , layer_emb

-    def gradient_checkpointing_enable(self):
-        pass
-

 class WindowQformer(nn.Module):
     """Window-level Qformer"""
@@ -1077,13 +995,6 @@ class WindowQformer(nn.Module):
         self.after_norm = (nn.LayerNorm(attention_dim, eps=1e-12)
                            if normalize_before else None)
         self.window_size = window_size
-        self.gradient_checkpointing_enable = False
-
-    def enable_gradient_checkpointing(self):
-        self.gradient_checkpointing_enable = True
-
-    def disable_gradient_checkpointing(self):
-        self.gradient_checkpointing_enable = False

     def forward(self, audio_embed, mask, embed_len=None):
         """forward decoder"""
@@ -1111,16 +1022,6 @@ class WindowQformer(nn.Module):
         # NT' x 1 x D
         q = self.queries.expand(bsz * slen, -1, -1)
         for layer in self.decoders:
-            if self.gradient_checkpointing_enable and self.training:
-                q = checkpoint(
-                    layer.__call__,
-                    q,
-                    embed_chunk,
-                    None,
-                    mask,
-                    use_reentrant=True,
-                )
-            else:
-                q = layer(tgt=q,
-                          memory=embed_chunk,
-                          tgt_mask=None,
+            q = layer(tgt=q,
+                      memory=embed_chunk,
+                      tgt_mask=None,
@@ -1147,13 +1048,6 @@ class AudioEmbedding(nn.Module):
         hidden_size = (config.n_embd
                        if hasattr(config, "n_embd") else config.hidden_size)

-        if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"):
-            embd_drop = (config.embd_pdrop if hasattr(config, "embd_pdrop")
-                         else config.embed_pdrop)
-            self.drop = nn.Dropout(embd_drop)
-        else:
-            self.drop = None
-
         # self.wte = nn.Embedding(config.vocab_size, hidden_size)

         audio_dim_out = (
@@ -1167,12 +1061,6 @@ class AudioEmbedding(nn.Module):
             assert encoder_config is not None
             self.encoder = ConformerEncoder(**encoder_config)

-            # fake initialization, create encoder_embedding layer only so that
-            # in decoding, all parameters can be loaded in
-            # from_pretrained_function in training, we do post init after
-            # from_pretrained function to make sure the correct initialization
-            self.encoder.post_init({})
-
             audio_dim_out = encoder_config["attention_dim"]
             n_mels = encoder_config["input_size"]
         else:
@@ -1221,14 +1109,6 @@ class AudioEmbedding(nn.Module):
         else:
             self.conv_ds = None

-        enable_gradient_checkpointing = kwargs.get(
-            "enable_gradient_checkpointing", False)
-        if enable_gradient_checkpointing:
-            self.encoder.gradient_checkpointing_enable()
-
-            if self.qformer:
-                self.qformer.enable_gradient_checkpointing()
-
         projection_cls = kwargs.get("projection_cls", "linear")
         if projection_cls == "linear":
             self.audio_projection = nn.Linear(audio_dim_out, hidden_size)
@@ -1388,16 +1268,4 @@ class AudioEmbedding(nn.Module):
                     hidden_states.dtype).to(hidden_states.device))
                 idx += cnt
-        else:
-            if self.training:
-                # hidden_states[:, 0:img_set_tensor.shape[0]] =
-                # hidden_states[:, 0:img_set_tensor.shape[0]] +
-                # 0 * img_set_tensor.to(hidden_states.dtype)
-                # .to(hidden_states.device)
-                hidden_states[:, 0:1] = hidden_states[:, 0:1] + \
-                    0 * audio_set_tensor[:, 0:1].to(hidden_states.dtype)\
-                    .to(hidden_states.device)
-
-        if self.drop is not None:
-            hidden_states = self.drop(hidden_states)
-
         return hidden_states
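Note: repeat(...) plus the per-layer checkpoint wrappers is replaced by building the encoder stack directly with MultiSequential. A toy sketch of that pattern follows; ToyLayer is a stand-in for ConformerEncoderLayer, not the real layer.

    import torch
    from torch import nn

    class MultiSequential(nn.Sequential):
        """Multi-input, multi-output Sequential: every submodule consumes and
        returns the same tuple of values."""

        def forward(self, *args):
            for module in self:
                args = module(*args)
            return args

    class ToyLayer(nn.Module):
        """Illustrative stand-in for ConformerEncoderLayer."""

        def __init__(self, d_model: int):
            super().__init__()
            self.proj = nn.Linear(d_model, d_model)

        def forward(self, x, mask):
            # Return a tuple so the next layer receives (x, mask) again.
            return self.proj(x), mask

    num_blocks, d_model = 3, 8
    encoders = MultiSequential(*[ToyLayer(d_model) for _ in range(num_blocks)])
    x = torch.randn(2, 5, d_model)
    mask = torch.ones(2, 5, dtype=torch.bool)
    x, mask = encoders(x, mask)
    print(x.shape)  # torch.Size([2, 5, 8])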

View File

@@ -5,14 +5,11 @@
 # but implemented by the Phi-Speech team
 #!/usr/bin/env python3
 import math
-from functools import partial
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
-from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    CheckpointImpl, checkpoint_wrapper, offload_wrapper)


 class Block(nn.Module):
@@ -873,10 +870,8 @@ class MeanVarianceNormLayer(nn.Module):
     def __init__(self, input_size):
         super().__init__()
         self.input_size = input_size
-        self.register_buffer("global_mean", torch.zeros(input_size))
-        self.register_buffer("global_invstd", torch.ones(input_size))
-        self.global_mean: Optional[Tensor]
-        self.global_invstd: Optional[Tensor]
+        self.global_mean = nn.Parameter(torch.zeros(input_size))
+        self.global_invstd = nn.Parameter(torch.ones(input_size))

     def forward(self, input_: Tensor) -> Tensor:
         """MeanVarianceNormLayer Forward
@@ -1023,17 +1018,6 @@ class CausalConv2D(nn.Conv2d):
         self,
         x,
     ):
-        if self.training:
-            x = F.pad(
-                x,
-                pad=(
-                    self._left_padding,
-                    self._right_padding,
-                    self._left_padding,
-                    self._right_padding,
-                ),
-            )
-        else:
-            x = F.pad(
-                x,
-                pad=(self._left_padding, self._right_padding, 0, 0),
+        x = F.pad(
+            x,
+            pad=(self._left_padding, self._right_padding, 0, 0),
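Note: only the inference branch of the padding survives. A tiny reminder of F.pad semantics for this call; the tensor shape and padding amounts are illustrative. The first pair pads the last (time) axis, the second pair would pad the next axis and is left at zero here.

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 1, 4, 10)        # illustrative (batch, channels, freq, time)
    left_padding, right_padding = 3, 0  # e.g. a causal kernel pads only the past
    y = F.pad(x, pad=(left_padding, right_padding, 0, 0))
    print(y.shape)                      # torch.Size([1, 1, 4, 13])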
@@ -1840,68 +1824,6 @@ class MultiHeadedAttention(nn.Module):
         return self.linear_out(x)  # (batch, time1, d_model)

-def validate_checkpointing_config(activation_checkpointing):
-    """validate activation checkpointing configuration"""
-    if isinstance(activation_checkpointing, str):
-        assert activation_checkpointing in (
-            "",
-            "checkpoint",
-            "offload",
-        ), "activation_checkpointing has to be a dict or a str in "\
-            "('', 'checkpoint', 'offload')."
-    elif isinstance(activation_checkpointing, dict):
-        assert activation_checkpointing.get("module", "transformer") in (
-            "transformer",
-            "attention",
-        ), "module in activation_checkpointing has to be in "\
-            "('transformer', 'attention')."
-    else:
-        raise ValueError("activation_checkpointing has to be a str"\
-            " or dict.")
-
-
-def embedding_checkpoint_wrapper(
-        activation_checkpointing: Union[str, Dict], ) -> Callable:
-    """return encoder embedding activation checkpoint wrapper"""
-    validate_checkpointing_config(activation_checkpointing)
-
-    if isinstance(activation_checkpointing, str):
-        if activation_checkpointing:
-            if activation_checkpointing == "offload":
-                return offload_wrapper
-            return partial(checkpoint_wrapper)
-        return lambda x: x
-
-    if isinstance(activation_checkpointing, dict):
-        enabled = activation_checkpointing.get("embed", False)
-        if enabled:
-            offloading = activation_checkpointing.get("offload", False)
-            if offloading:
-                return offload_wrapper
-            impl = (CheckpointImpl.REENTRANT if activation_checkpointing.get(
-                "reentrant", False) else CheckpointImpl.NO_REENTRANT)
-            return partial(checkpoint_wrapper, checkpoint_impl=impl)
-        return lambda x: x
-
-    raise ValueError("Invalid activation_checkpointing config")
-
-
-def attn_checkpointing(activation_checkpointing: Union[str, Dict],
-                       i) -> Union[str, Dict]:
-    """return activation checkpointing config for attention layer"""
-    if isinstance(activation_checkpointing, str):
-        return ""
-
-    if isinstance(activation_checkpointing, dict):
-        target_layer_cls = activation_checkpointing.get(
-            "module", "transformer")
-        checkpointing_interval = activation_checkpointing.get("interval", 1)
-        if target_layer_cls == "attention" and i % checkpointing_interval == 0:
-            return activation_checkpointing
-        return ""
-
-    raise ValueError("Invalid activation_checkpointing config")
-

 class MultiSequential(torch.nn.Sequential):
     """Multi-input multi-output torch.nn.Sequential"""
@@ -1913,17 +1835,6 @@ class MultiSequential(torch.nn.Sequential):
         return args

-def repeat(repeat_num, module_gen_fn):
-    """repeat module N times
-
-    :param int repeat_num: repeat time
-    :param function module_gen_fn: function to generate module
-    :return: repeated modules
-    :rtype: MultiSequential
-    """
-    return MultiSequential(*[module_gen_fn(i) for i in range(repeat_num)])
-
-
 def get_offset(input_layer: str, time_reduction: int):
     """Get an offset. We will use the offset for determining #frames of a
     subsampled feature.