diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py
new file mode 100644
index 00000000..b1c46190
--- /dev/null
+++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py
@@ -0,0 +1,109 @@
+# SPDX-License-Identifier: Apache-2.0
+import math
+from dataclasses import dataclass
+
+import torch
+
+from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.attention.backends.placeholder_attn import (
+    PlaceholderAttentionMetadata)
+from vllm.attention.backends.xformers import XFormersMetadata
+
+
+@dataclass
+class Mamba2Metadata:
+    has_prefill: bool
+
+    has_initial_states: torch.Tensor
+    prep_initial_states: bool
+
+    chunk_size: int
+    seq_idx: torch.Tensor
+    chunk_indices: torch.Tensor
+    chunk_offsets: torch.Tensor
+
+
+def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
+
+    # convert seq_idx to chunk indices and offsets
+    # - derive the cu_seqlens
+    _, cu_seqlens = torch.where(seq_idx.diff())
+    cu_seqlens += 1
+
+    # outputs will have length expansion of chunks that do not divide
+    # chunk_size
+    N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
+                                                     > 0).sum()
+    chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
+    chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
+
+    cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
+    p = 0  # num of insertions
+    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
+
+        # if s is not a multiple of chunk_size, there is one chunk insertion
+        p += (s % chunk_size > 0)
+
+        # get the dimensions
+        # - the + 1 for _e is to shift the boundary by one chunk
+        # - this shifting is not needed if chunk_size divides e
+        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
+                                                             > 0)
+
+        # adjust indices and offsets
+        chunk_indices[_s:_e] -= p
+        chunk_offsets[_s] = s % chunk_size
+
+    return chunk_indices, chunk_offsets
+
+
+def prepare_mamba2_metadata(
+    chunk_size: int,
+    input_ids: torch.Tensor,
+    attn_metadata: AttentionMetadata,
+) -> Mamba2Metadata:
+
+    # Need flags to indicate if there are initial states
+    # currently we really only support the FlashAttention backend
+    has_initial_states = None
+    prep_initial_states = False
+    if (isinstance(attn_metadata, (FlashAttentionMetadata, XFormersMetadata,
+                                   PlaceholderAttentionMetadata))
+            and attn_metadata.context_lens_tensor is not None):
+        has_initial_states = attn_metadata.context_lens_tensor > 0
+        # precompute flag to avoid device syncs later in mamba2 forwards
+        prep_initial_states = torch.any(has_initial_states).item()
+
+    has_prefill = attn_metadata.num_prefills > 0
+
+    seq_idx = None
+    chunk_indices, chunk_offsets = None, None
+    if has_prefill:
+        seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
+        for i, (srt, end) in enumerate(
+                zip(
+                    attn_metadata.query_start_loc,
+                    attn_metadata.query_start_loc[1:],
+                )):
+            seq_idx[srt:end] = i
+        seq_idx.unsqueeze_(0)
+
+        # compute metadata for chunked prefill. Strictly, this is
+        # needed only when there are initial states, but that is
+        # determinable only from attention metadata that is not yet
+        # available at the top-level model forward. Rather than
+        # complicating things to extract said metadata, we simply
+        # compute it once at the top-level model forward and reuse
+        # it in the mamba layers. If it is not needed, it is ignored
+        # inside the mamba kernels.
+        chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
+            seq_idx, chunk_size)
+
+    return Mamba2Metadata(has_prefill=has_prefill,
+                          has_initial_states=has_initial_states,
+                          prep_initial_states=prep_initial_states,
+                          chunk_size=chunk_size,
+                          seq_idx=seq_idx,
+                          chunk_indices=chunk_indices,
+                          chunk_offsets=chunk_offsets)
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index d7a45bc5..d459c93a 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -6,10 +6,6 @@ import torch
 from torch import nn
 
 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.attention.backends.placeholder_attn import (
-    PlaceholderAttentionMetadata)
-from vllm.attention.backends.xformers import XFormersMetadata
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather,
@@ -18,6 +14,7 @@ from vllm.forward_context import get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@@ -221,7 +218,6 @@ class MambaMixer2(CustomOp):
                  head_dim: int = 64,
                  rms_norm_eps: float = 1e-5,
                  activation="silu",
-                 chunk_size: int = 256,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
 
@@ -257,7 +253,6 @@ class MambaMixer2(CustomOp):
 
         self.ssm_state_size = ssm_state_size
         self.activation = activation
-        self.chunk_size = chunk_size
         self.intermediate_size = intermediate_size
         self.head_dim = head_dim
         self.num_heads = num_heads
@@ -388,25 +383,17 @@ class MambaMixer2(CustomOp):
         self,
         hidden_states: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor] = None,
+        mamba2_metadata: Mamba2Metadata,
     ):
+        # mamba2_metadata contains metadata necessary for the mamba2 triton
+        # kernels to operate in continuous batching and in chunked prefill
+        # modes; it is computed once at the top-level model forward and
+        # reused by all mamba layers in the same iteration
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
 
         seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size
 
-        # detect if there are prefills
-        has_prefill = attn_metadata.num_prefills > 0
-
-        # - also need flags to indicate if there are initial states
-        # - currently we really only support the FlashAttention backend
-        has_initial_states = None
-        if (isinstance(attn_metadata,
-                       (FlashAttentionMetadata, XFormersMetadata,
-                        PlaceholderAttentionMetadata))
-                and attn_metadata.context_lens_tensor is not None):
-            has_initial_states = attn_metadata.context_lens_tensor > 0
-
         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
         gate, hidden_states_B_C, dt = torch.split(
@@ -423,7 +410,7 @@ class MambaMixer2(CustomOp):
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
 
-        if has_prefill:
+        if mamba2_metadata.has_prefill:
             # |---------- N-1 iteration --------|
             # |---------------- N iteration ---------------------|
             # |- tokenA -|......................|-- newTokens ---|
@@ -439,7 +426,7 @@ class MambaMixer2(CustomOp):
                 self.conv1d.bias,
                 activation=self.activation,
                 conv_states=mamba_cache_params.conv_state,
-                has_initial_state=has_initial_states,
+                has_initial_state=mamba2_metadata.has_initial_states,
                 cache_indices=mamba_cache_params.state_indices_tensor,
                 query_start_loc=attn_metadata.query_start_loc).transpose(
                     0, 1)[:seq_len]
@@ -467,16 +454,15 @@ class MambaMixer2(CustomOp):
         )
 
         # 3. State Space Model sequence transformation
-        if has_prefill:
-
+        if mamba2_metadata.has_prefill:
             initial_states = None
-            if has_initial_states is not None and torch.any(
-                    has_initial_states):
-                zero_init_indices = mamba_cache_params.state_indices_tensor[
-                    ~has_initial_states]
-                mamba_cache_params.ssm_state[zero_init_indices] = 0
-                initial_states = mamba_cache_params.ssm_state[
-                    mamba_cache_params.state_indices_tensor]
+            if (mamba2_metadata.has_initial_states is not None
+                    and mamba2_metadata.prep_initial_states):
+                # making a copy of the states
+                initial_states = torch.where(
+                    mamba2_metadata.has_initial_states[:, None, None, None],
+                    mamba_cache_params.ssm_state[
+                        mamba_cache_params.state_indices_tensor], 0)
 
             scan_output, varlen_state = mamba_chunk_scan_combined(
                 hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
@@ -485,11 +471,13 @@ class MambaMixer2(CustomOp):
                 self.A,
                 B.view(1, seq_len, self.n_groups // self.tp_size, -1),
                 C.view(1, seq_len, self.n_groups // self.tp_size, -1),
-                chunk_size=self.chunk_size,
+                chunk_size=mamba2_metadata.chunk_size,
                 D=self.D,
                 z=None,
                 dt_bias=self.dt_bias,
-                seq_idx=sequence_idx,
+                seq_idx=mamba2_metadata.seq_idx,
+                chunk_indices=mamba2_metadata.chunk_indices,
+                chunk_offsets=mamba2_metadata.chunk_offsets,
                 cu_seqlens=attn_metadata.query_start_loc,
                 initial_states=initial_states,
                 return_varlen_states=True,
diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
index 7ef51112..005917f2 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
@@ -5,8 +5,6 @@
 
 # ruff: noqa: E501,SIM102
 
-import math
-
 import torch
 import triton
 import triton.language as tl
@@ -442,40 +440,6 @@ def _chunk_scan_fwd_kernel(
              (offs_out_n[None, :] < hdim))
 
 
-def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
-
-    # convert seq_idx to chunk indices and offsets
-    # - derive the cu_seqlens
-    _, cu_seqlens = torch.where(seq_idx.diff())
-    cu_seqlens += 1
-
-    # outputs will have length expansion of chunks that do not divide
-    # chunk_size
-    N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
-                                                     > 0).sum()
-    chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
-    chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
-
-    cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
-    p = 0  # num of insertions
-    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
-
-        # if does not divide chunk_size, then there is one chunk insertion
-        p += (s % chunk_size > 0)
-
-        # get the dimensions
-        # - the + 1 for _e is to shift the boundary by one chunk
-        # - this shifting is not needed if chunk_size divides e
-        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
-                                                             > 0)
-
-        # adjust inidces and offsets
-        chunk_indices[_s:_e] -= p
-        chunk_offsets[_s] = s % chunk_size
-
-    return chunk_indices, chunk_offsets
-
-
 def _chunk_scan_fwd(
     cb,
     x,
@@ -486,6 +450,8 @@ def _chunk_scan_fwd(
     D=None,
     z=None,
     seq_idx=None,
+    chunk_indices=None,
+    chunk_offsets=None,
     initial_states=None,
 ):
     batch, seqlen, nheads, headdim = x.shape
@@ -502,7 +468,6 @@ def _chunk_scan_fwd(
     assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
     assert states.shape == (batch, nchunks, nheads, headdim, dstate)
 
-    chunk_indices, chunk_offsets = None, None
     if seq_idx is not None:
         assert seq_idx.shape == (batch, seqlen)
 
@@ -510,15 +475,19 @@ def _chunk_scan_fwd(
             # with initial states, we need to take care of how
             # seq_idx crosses the boundaries
             assert batch == 1, "chunk scan only supports initial states with batch 1"
-            assert initial_states.shape == (seq_idx[0].max() + 1, nheads,
-                                            headdim, dstate)
 
             if initial_states.shape[0] == 1:
                 # no in this case no point to use initial states
                 initial_states = None
             else:
-                chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
-                    seq_idx, chunk_size)
+                assert chunk_indices is not None and chunk_offsets is not None, \
+                (
+                    "chunk_indices and chunk_offsets should have been set"
+                )
+            else:
+                chunk_indices, chunk_offsets = None, None
+        else:
+            chunk_indices, chunk_offsets = None, None
 
     # Allocates output.
     out = torch.empty(batch,
diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py
index 97cdb70b..3febd4cc 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py
@@ -30,6 +30,8 @@ def _mamba_chunk_scan_combined_fwd(x,
                                    dt_bias=None,
                                    initial_states=None,
                                    seq_idx=None,
+                                   chunk_indices=None,
+                                   chunk_offsets=None,
                                    cu_seqlens=None,
                                    dt_softplus=False,
                                    dt_limit=(0.0, float("inf"))):
@@ -96,7 +98,7 @@ def _mamba_chunk_scan_combined_fwd(x,
     # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
     # (middle term of factorization of off-diag blocks; A terms)
     # - for handling chunked prefill, this requires i) initial_states
-    # ii) seq_idx and iii) has_cu_seqlens to be all specified.
+    # ii) seq_idx and iii) is_cont_batched to all be specified.
     # - When a new seq_idx is detected, we will stop passing the prev_state
     # and switch accordingly to the init_state corresponding to the new seq_idx.
     # - this will ensure that states will be updated with the rightmost flushed seq_idx
@@ -141,6 +143,8 @@ def _mamba_chunk_scan_combined_fwd(x,
         D=D,
         z=z,
         seq_idx=seq_idx,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
         initial_states=initial_states,
     )
     if cu_seqlens is None:
@@ -170,6 +174,8 @@ def mamba_chunk_scan_combined(x,
                               dt_bias=None,
                               initial_states=None,
                               seq_idx=None,
+                              chunk_indices=None,
+                              chunk_offsets=None,
                               cu_seqlens=None,
                               dt_softplus=False,
                               dt_limit=(0.0, float("inf")),
@@ -210,6 +216,8 @@ def mamba_chunk_scan_combined(x,
                                           dt_bias=dt_bias,
                                           initial_states=initial_states,
                                           seq_idx=seq_idx,
+                                          chunk_indices=chunk_indices,
+                                          chunk_offsets=chunk_offsets,
                                           cu_seqlens=cu_seqlens,
                                           dt_softplus=dt_softplus,
                                           dt_limit=dt_limit)
diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
index d8f87c11..219c5306 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
@@ -150,8 +150,6 @@ def _state_passing_fwd(
             # are used for continuous batching. In which case we
             # require seq_idx to be provided
             assert seq_idx is not None, ""
-            assert initial_states.shape == (seq_idx.max().item() + 1, nheads,
-                                            dim)
         else:
             # - this is the regular batching case, where initial
             # states are used are for each example of the batch.
diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py
index e5896f5f..dfb8f49c 100644
--- a/vllm/model_executor/models/bamba.py
+++ b/vllm/model_executor/models/bamba.py
@@ -18,6 +18,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba2_metadata import (
+    Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     MambaMixer2, extra_groups_for_head_shards)
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -94,7 +96,6 @@ class BambaMixerDecoderLayer(nn.Module):
                                  head_dim=config.mamba_d_head,
                                  rms_norm_eps=config.rms_norm_eps,
                                  activation=config.hidden_act,
-                                 chunk_size=config.mamba_chunk_size,
                                  quant_config=quant_config)
 
         self.feed_forward = BambaMLP(config, quant_config=quant_config)
@@ -108,7 +109,7 @@ class BambaMixerDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor] = None,
+        mamba2_metadata: Mamba2Metadata,
         **kwargs,
     ):
         if residual is None:
@@ -119,7 +120,7 @@ class BambaMixerDecoderLayer(nn.Module):
                 hidden_states, residual)
 
         hidden_states = self.mamba(hidden_states, mamba_cache_params,
-                                   sequence_idx)
+                                   mamba2_metadata)
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
             hidden_states, residual)
@@ -259,7 +260,7 @@ class BambaModel(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
 
-        config = vllm_config.model_config.hf_config
+        config: BambaConfig = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
@@ -309,20 +310,13 @@ class BambaModel(nn.Module):
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
-        # pass a sequence index tensor, that is required for
-        # proper continuous batching computation including
-        # chunked prefill
-        seq_idx = None
         attn_metadata = get_forward_context().attn_metadata
-        if attn_metadata.num_prefills > 0:
-            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
-            for i, (srt, end) in enumerate(
-                    zip(
-                        attn_metadata.query_start_loc,
-                        attn_metadata.query_start_loc[1:],
-                    )):
-                seq_idx[srt:end] = i
-            seq_idx.unsqueeze_(0)
+
+        mamba2_metadata = prepare_mamba2_metadata(
+            chunk_size=self.config.mamba_chunk_size,
+            input_ids=input_ids,
+            attn_metadata=attn_metadata,
+        )
 
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
@@ -352,7 +346,7 @@ class BambaModel(nn.Module):
                 hidden_states=hidden_states,
                 residual=residual,
                 mamba_cache_params=layer_mamba_cache_params,
-                sequence_idx=seq_idx,
+                mamba2_metadata=mamba2_metadata,
             )
 
         if not get_pp_group().is_last_rank:
@@ -555,4 +549,4 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
\ No newline at end of file
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py
index da5cbddb..526dec46 100644
--- a/vllm/model_executor/models/mamba2.py
+++ b/vllm/model_executor/models/mamba2.py
@@ -13,6 +13,8 @@ from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba2_metadata import (
+    Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     MambaMixer2, extra_groups_for_head_shards)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -57,7 +59,6 @@ class Mamba2DecoderLayer(nn.Module):
                                  head_dim=config.head_dim,
                                  rms_norm_eps=config.layer_norm_epsilon,
                                  activation=config.hidden_act,
-                                 chunk_size=config.chunk_size,
                                  quant_config=quant_config)
 
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -67,7 +68,7 @@ class Mamba2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor],
+        mamba2_metadata: Mamba2Metadata,
         **kwargs,
     ):
         if residual is None:
@@ -77,7 +78,7 @@ class Mamba2DecoderLayer(nn.Module):
             hidden_states, residual = self.norm(hidden_states, residual)
 
         hidden_states = self.mixer(hidden_states, mamba_cache_params,
-                                   sequence_idx)
+                                   mamba2_metadata)
         return hidden_states, residual
 
 
@@ -138,20 +139,13 @@ class Mamba2Model(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        # pass a sequence index tensor, that is required for
-        # proper continuous batching computation including
-        # chunked prefill
-        seq_idx = None
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
-        if attn_metadata.num_prefills > 0:
-            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
-            for i, (srt, end) in enumerate(
-                    zip(
-                        attn_metadata.query_start_loc,
-                        attn_metadata.query_start_loc[1:],
-                    )):
-                seq_idx[srt:end] = i
-            seq_idx.unsqueeze_(0)
+
+        mamba2_metadata = prepare_mamba2_metadata(
+            chunk_size=self.config.chunk_size,
+            input_ids=input_ids,
+            attn_metadata=attn_metadata,
+        )
 
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -162,7 +156,7 @@ class Mamba2Model(nn.Module):
                 residual=residual,
                 mamba_cache_params=mamba_cache_params.at_layer_idx(
                     i - self.start_layer),
-                sequence_idx=seq_idx)
+                mamba2_metadata=mamba2_metadata)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py
index c5330203..ea21fffa 100644
--- a/vllm/model_executor/models/zamba2.py
+++ b/vllm/model_executor/models/zamba2.py
@@ -25,6 +25,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba2_metadata import (
+    Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     MambaMixer2, extra_groups_for_head_shards)
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -495,7 +497,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
             head_dim=intermediate_size // config.n_mamba_heads,
             rms_norm_eps=config.rms_norm_eps,
             activation="silu",
-            chunk_size=config.chunk_size,
             quant_config=quant_config,
         )
 
@@ -507,7 +508,7 @@ class Zamba2MambaDecoderLayer(nn.Module):
         self,
         hidden_states: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor] = None,
+        mamba2_metadata: Mamba2Metadata,
         transformer_hidden_states: Optional[torch.Tensor] = None,
         positions: Optional[torch.Tensor] = None,
         original_hidden_states: Optional[torch.Tensor] = None,
@@ -547,7 +548,7 @@ class Zamba2MambaDecoderLayer(nn.Module):
         hidden_states = self.mamba(
             hidden_states,
             mamba_cache_params=mamba_cache_params,
-            sequence_idx=sequence_idx,
+            mamba2_metadata=mamba2_metadata,
         )
 
         # residual connection after mamba
@@ -594,8 +595,8 @@ class Zamba2HybridLayer(nn.Module):
         hidden_states: torch.Tensor,
         original_hidden_states: torch.Tensor,
         positions: torch.Tensor,
-        mamba_cache_params: Optional[MambaCacheParams] = None,
-        sequence_idx: Optional[torch.Tensor] = None,
+        mamba_cache_params: MambaCacheParams,
+        mamba2_metadata: Mamba2Metadata,
     ) -> torch.Tensor:
         """Forward pass through the hybrid layer.
 
@@ -634,7 +635,7 @@ class Zamba2HybridLayer(nn.Module):
             hidden_states,
             transformer_hidden_states=transformer_hidden_states,
             mamba_cache_params=mamba_cache_params,
-            sequence_idx=sequence_idx,
+            mamba2_metadata=mamba2_metadata,
         )
 
         return layer_outputs
@@ -747,20 +748,13 @@ class Zamba2Model(nn.Module):
             inputs_embeds = self.get_input_embeddings(input_ids)
         hidden_states = inputs_embeds
 
-        # pass a sequence index tensor, that is required for
-        # proper continuous batching computation including
-        # chunked prefill
-        seq_idx = None
         attn_metadata = get_forward_context().attn_metadata
-        if attn_metadata.num_prefills > 0:
-            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
-            for i, (srt, end) in enumerate(
-                    zip(
-                        attn_metadata.query_start_loc,
-                        attn_metadata.query_start_loc[1:],
-                    )):
-                seq_idx[srt:end] = i
-            seq_idx.unsqueeze_(0)
+
+        mamba2_metadata = prepare_mamba2_metadata(
+            chunk_size=self.config.chunk_size,
+            input_ids=input_ids,
+            attn_metadata=attn_metadata,
+        )
 
         # Process through layers
         original_hidden_states = torch.clone(hidden_states)
@@ -770,7 +764,7 @@ class Zamba2Model(nn.Module):
                 original_hidden_states=original_hidden_states,
                 positions=positions,
                 mamba_cache_params=mamba_cache_params.at_layer_idx(layer_idx),
-                sequence_idx=seq_idx,
+                mamba2_metadata=mamba2_metadata,
            )
 
             hidden_states = layer_outputs
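
Note (editor's addition, not part of the patch): the least obvious piece of the new mamba2_metadata.py is the seq_idx -> (chunk_indices, chunk_offsets) mapping, which lets the chunk-scan kernel re-read a physical chunk at an offset whenever a sequence boundary falls inside it. The standalone sketch below mirrors _seq_idx_to_chunk_indices_offsets on a toy batch; the function name, the int(...) cast on N (so the sketch also runs on plain CPU tensors), and the example sizes are illustrative assumptions, not taken from the patch.

# toy_chunk_indices.py -- illustrative sketch only
import math

import torch


def seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
    # mirrors _seq_idx_to_chunk_indices_offsets from mamba2_metadata.py;
    # N is cast to int here purely so this standalone sketch runs anywhere
    _, cu_seqlens = torch.where(seq_idx.diff())
    cu_seqlens += 1

    N = int(
        math.ceil(seq_idx.shape[-1] / chunk_size) +
        (cu_seqlens % chunk_size > 0).sum())
    chunk_indices = torch.arange(N, dtype=torch.int)
    chunk_offsets = torch.zeros((N, ), dtype=torch.int)

    cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
    p = 0  # number of extra logical chunks inserted so far
    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        # a sequence that starts mid-chunk splits that physical chunk in two
        p += (s % chunk_size > 0)
        _s = s // chunk_size + p
        _e = e // chunk_size + p + (e % chunk_size > 0)
        chunk_indices[_s:_e] -= p
        chunk_offsets[_s] = s % chunk_size
    return chunk_indices, chunk_offsets


if __name__ == "__main__":
    # seq_idx as built by prepare_mamba2_metadata: one row, one entry per
    # token, holding the index of the sequence that token belongs to
    seq_idx = torch.tensor([[0] * 5 + [1] * 7], dtype=torch.int32)
    chunk_indices, chunk_offsets = seq_idx_to_chunk_indices_offsets(
        seq_idx, chunk_size=4)
    # 12 tokens span 3 physical chunks; chunk 1 straddles the boundary at
    # token 5, so it is visited twice -> 4 logical chunks:
    #   chunk_indices -> tensor([0, 1, 1, 2], dtype=torch.int32)
    #   chunk_offsets -> tensor([0, 0, 1, 0], dtype=torch.int32)
    print(chunk_indices, chunk_offsets)

In words: logical chunk 2 re-reads physical chunk 1 starting at offset 1, so the second sequence can be scanned from its own initial state instead of inheriting the first sequence's partial chunk.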
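
Note (editor's addition, not part of the patch): the other behavioural change in mamba_mixer2.py is how initial SSM states are gathered for prefill. The removed code zeroed cache rows in place for sequences without prior context and then indexed the cache; the new code gathers once and masks with torch.where, gated by the prep_initial_states flag that prepare_mamba2_metadata materializes with a single .item() call so the per-layer forwards avoid an extra device sync. A minimal sketch of that selection, with made-up shapes and names:

# initial_state_selection.py -- illustrative sketch only
import torch

num_cache_slots, nheads, headdim, dstate = 8, 4, 16, 32
ssm_state = torch.randn(num_cache_slots, nheads, headdim, dstate)

# three sequences in the batch, mapped to arbitrary cache slots
state_indices_tensor = torch.tensor([5, 0, 2])
# only the first and third sequences have previously computed context
has_initial_states = torch.tensor([True, False, True])

initial_states = torch.where(
    has_initial_states[:, None, None, None],  # broadcast over head/dim/state
    ssm_state[state_indices_tensor],  # advanced indexing returns a copy
    0)

assert initial_states.shape == (3, nheads, headdim, dstate)
assert torch.equal(initial_states[0], ssm_state[5])
assert torch.equal(initial_states[1], torch.zeros(nheads, headdim, dstate))

Unlike the removed zero_init_indices path, this builds the initial states without writing to the cached ssm_state at all.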