[Model] Reduce redundant computations in mamba2 blocks for Bamba-9B (#15423)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Chih-Chieh Yang 2025-04-10 15:07:07 -04:00 committed by GitHub
parent 5fbab20e02
commit daefed052c
8 changed files with 186 additions and 132 deletions

View File

@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
import math
from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata)
from vllm.attention.backends.xformers import XFormersMetadata
@dataclass
class Mamba2Metadata:
has_prefill: bool
has_initial_states: torch.Tensor
prep_initial_states: bool
chunk_size: int
seq_idx: torch.Tensor
chunk_indices: torch.Tensor
chunk_offsets: torch.Tensor
def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
# convert seq_idx to chunk indices and offsets
# - derive the cu_seqlens
_, cu_seqlens = torch.where(seq_idx.diff())
cu_seqlens += 1
# the output is longer than the number of chunks by one entry for every
# sequence boundary that does not fall on a multiple of chunk_size
N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
> 0).sum()
chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
p = 0 # num of insertions
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
# if the boundary s is not a multiple of chunk_size, one extra chunk is inserted
p += (s % chunk_size > 0)
# get the dimensions
# - the + 1 for _e is to shift the boundary by one chunk
# - this shifting is not needed if chunk_size divides e
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
> 0)
# adjust indices and offsets
chunk_indices[_s:_e] -= p
chunk_offsets[_s] = s % chunk_size
return chunk_indices, chunk_offsets
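A worked example may help; the values below are illustrative only (three prompts of lengths 3, 5 and 2 packed into one prefill batch, chunk_size=4):
# Illustrative example: boundaries fall at tokens 3 and 8; only the first
# one is not chunk-aligned, so physical chunk 0 is split into two virtual
# chunks that both map back to logical chunk 0, the second at offset 3.
example_seq_idx = torch.tensor([[0, 0, 0, 1, 1, 1, 1, 1, 2, 2]],
                               dtype=torch.int32)
ci, co = _seq_idx_to_chunk_indices_offsets(example_seq_idx, chunk_size=4)
# ci -> tensor([0, 0, 1, 2]), co -> tensor([0, 3, 0, 0])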
def prepare_mamba2_metadata(
chunk_size: int,
input_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Mamba2Metadata:
# Need flags to indicate if there are initial states
# currently we really only support the FlashAttention backend
has_initial_states = None
prep_initial_states = False
if (isinstance(attn_metadata, (FlashAttentionMetadata, XFormersMetadata,
PlaceholderAttentionMetadata))
and attn_metadata.context_lens_tensor is not None):
has_initial_states = attn_metadata.context_lens_tensor > 0
# precompute flag to avoid device syncs later in mamba2 forwards
prep_initial_states = torch.any(has_initial_states).item()
has_prefill = attn_metadata.num_prefills > 0
seq_idx = None
chunk_indices, chunk_offsets = None, None
if has_prefill:
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
for i, (srt, end) in enumerate(
zip(
attn_metadata.query_start_loc,
attn_metadata.query_start_loc[1:],
)):
seq_idx[srt:end] = i
seq_idx.unsqueeze_(0)
# compute metadata for chunked prefill.
# actually this is only needed if there are initial states,
# but that is determinable only from attention metadata which is
# not yet available at the top-level model forward. Rather than
# complicating things to extract said metadata, we simply compute
# these once at the top-level model forward and reuse them in the
# mamba layers. If not needed, they are ignored inside the mamba
# kernels.
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
seq_idx, chunk_size)
return Mamba2Metadata(has_prefill=has_prefill,
has_initial_states=has_initial_states,
prep_initial_states=prep_initial_states,
chunk_size=chunk_size,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets)
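For reference, a hypothetical prefill batch of two fresh prompts of lengths 5 and 3 with chunk_size 4 would produce metadata roughly like this (illustrative values, built by hand rather than from a real AttentionMetadata):
# Hand-built illustration of the resulting metadata (hypothetical batch)
Mamba2Metadata(
    has_prefill=True,
    has_initial_states=torch.tensor([False, False]),  # fresh prompts, no cached context
    prep_initial_states=False,
    chunk_size=4,
    seq_idx=torch.tensor([[0, 0, 0, 0, 0, 1, 1, 1]], dtype=torch.int32),
    chunk_indices=torch.tensor([0, 1, 1], dtype=torch.int32),
    chunk_offsets=torch.tensor([0, 0, 1], dtype=torch.int32),
)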

View File

@ -6,10 +6,6 @@ import torch
from torch import nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata)
from vllm.attention.backends.xformers import XFormersMetadata
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
@ -18,6 +14,7 @@ from vllm.forward_context import get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@ -221,7 +218,6 @@ class MambaMixer2(CustomOp):
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation="silu",
chunk_size: int = 256,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
@ -257,7 +253,6 @@ class MambaMixer2(CustomOp):
self.ssm_state_size = ssm_state_size
self.activation = activation
self.chunk_size = chunk_size
self.intermediate_size = intermediate_size
self.head_dim = head_dim
self.num_heads = num_heads
@ -388,25 +383,17 @@ class MambaMixer2(CustomOp):
self,
hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor] = None,
mamba2_metadata: Mamba2Metadata,
):
# mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# are the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
seq_len, _ = hidden_states.shape
groups_time_state_size = self.n_groups * self.ssm_state_size
# detect if there are prefills
has_prefill = attn_metadata.num_prefills > 0
# - also need flags to indicate if there are initial states
# - currently we really only support the FlashAttention backend
has_initial_states = None
if (isinstance(attn_metadata,
(FlashAttentionMetadata, XFormersMetadata,
PlaceholderAttentionMetadata))
and attn_metadata.context_lens_tensor is not None):
has_initial_states = attn_metadata.context_lens_tensor > 0
# 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states)
gate, hidden_states_B_C, dt = torch.split(
@ -423,7 +410,7 @@ class MambaMixer2(CustomOp):
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if has_prefill:
if mamba2_metadata.has_prefill:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
@ -439,7 +426,7 @@ class MambaMixer2(CustomOp):
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=has_initial_states,
has_initial_state=mamba2_metadata.has_initial_states,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc).transpose(
0, 1)[:seq_len]
@ -467,16 +454,15 @@ class MambaMixer2(CustomOp):
)
# 3. State Space Model sequence transformation
if has_prefill:
if mamba2_metadata.has_prefill:
initial_states = None
if has_initial_states is not None and torch.any(
has_initial_states):
zero_init_indices = mamba_cache_params.state_indices_tensor[
~has_initial_states]
mamba_cache_params.ssm_state[zero_init_indices] = 0
initial_states = mamba_cache_params.ssm_state[
mamba_cache_params.state_indices_tensor]
if (mamba2_metadata.has_initial_states is not None
and mamba2_metadata.prep_initial_states):
# making a copy of the states
initial_states = torch.where(
mamba2_metadata.has_initial_states[:, None, None, None],
mamba_cache_params.ssm_state[
mamba_cache_params.state_indices_tensor], 0)
scan_output, varlen_state = mamba_chunk_scan_combined(
hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
@ -485,11 +471,13 @@ class MambaMixer2(CustomOp):
self.A,
B.view(1, seq_len, self.n_groups // self.tp_size, -1),
C.view(1, seq_len, self.n_groups // self.tp_size, -1),
chunk_size=self.chunk_size,
chunk_size=mamba2_metadata.chunk_size,
D=self.D,
z=None,
dt_bias=self.dt_bias,
seq_idx=sequence_idx,
seq_idx=mamba2_metadata.seq_idx,
chunk_indices=mamba2_metadata.chunk_indices,
chunk_offsets=mamba2_metadata.chunk_offsets,
cu_seqlens=attn_metadata.query_start_loc,
initial_states=initial_states,
return_varlen_states=True,
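Compared to the removed in-place zeroing of cache slots, the torch.where selection above gathers the initial states without mutating mamba_cache_params.ssm_state and without a host sync on the mask. A minimal standalone sketch of that selection, with made-up shapes:
# Made-up shapes: 2 requests, 1 head, 2x2 state; only request 1 has cached
# context, so request 0 gets a zero initial state in the gathered copy while
# the cache tensor itself is left untouched.
mask = torch.tensor([False, True])
cached = torch.ones(2, 1, 2, 2)
init = torch.where(mask[:, None, None, None], cached, 0)
# init[0] is all zeros; init[1] equals cached[1]; cached is unchanged.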

View File

@ -5,8 +5,6 @@
# ruff: noqa: E501,SIM102
import math
import torch
import triton
import triton.language as tl
@ -442,40 +440,6 @@ def _chunk_scan_fwd_kernel(
(offs_out_n[None, :] < hdim))
def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
# convert seq_idx to chunk indices and offsets
# - derive the cu_seqlens
_, cu_seqlens = torch.where(seq_idx.diff())
cu_seqlens += 1
# outputs will have length expansion of chunks that do not divide
# chunk_size
N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
> 0).sum()
chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
p = 0 # num of insertions
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
# if does not divide chunk_size, then there is one chunk insertion
p += (s % chunk_size > 0)
# get the dimensions
# - the + 1 for _e is to shift the boundary by one chunk
# - this shifting is not needed if chunk_size divides e
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
> 0)
# adjust inidces and offsets
chunk_indices[_s:_e] -= p
chunk_offsets[_s] = s % chunk_size
return chunk_indices, chunk_offsets
def _chunk_scan_fwd(
cb,
x,
@ -486,6 +450,8 @@ def _chunk_scan_fwd(
D=None,
z=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
initial_states=None,
):
batch, seqlen, nheads, headdim = x.shape
@ -502,7 +468,6 @@ def _chunk_scan_fwd(
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
assert states.shape == (batch, nchunks, nheads, headdim, dstate)
chunk_indices, chunk_offsets = None, None
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
@ -510,15 +475,19 @@ def _chunk_scan_fwd(
# with initial states, we need to take care of how
# seq_idx crosses the boundaries
assert batch == 1, "chunk scan only supports initial states with batch 1"
assert initial_states.shape == (seq_idx[0].max() + 1, nheads,
headdim, dstate)
if initial_states.shape[0] == 1:
# in this case there is no point in using initial states
initial_states = None
else:
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
seq_idx, chunk_size)
assert chunk_indices is not None and chunk_offsets is not None, \
(
"chunk_indices and chunk_offsets should have been set"
)
else:
chunk_indices, chunk_offsets = None, None
else:
chunk_indices, chunk_offsets = None, None
# Allocates output.
out = torch.empty(batch,

View File

@ -30,6 +30,8 @@ def _mamba_chunk_scan_combined_fwd(x,
dt_bias=None,
initial_states=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf"))):
@ -96,7 +98,7 @@ def _mamba_chunk_scan_combined_fwd(x,
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
# - for handling chunked prefill, this requires i) initial_states
# ii) seq_idx and iii) has_cu_seqlens to be all specified.
# ii) seq_idx and iii) is_cont_batched to be all specified.
# - When a new seq_idx is detected, we will stop passing the prev_state
# and switch accordingly to the init_state corresponding to the new seq_idx.
# - this will ensure that states will be updated with the rightmost flushed seq_idx
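The boundary handling described above can be pictured with a naive per-token scan; this is a schematic reference only (not the chunked kernel math), and step() stands in for the per-token SSM update:
def _reference_varlen_scan(x, seq_idx, initial_states, step):
    # x: (seqlen, ...), seq_idx: (seqlen,), initial_states indexed by seq_idx
    outputs = []
    state = None
    for t in range(x.shape[0]):
        if t == 0 or seq_idx[t] != seq_idx[t - 1]:
            # a new sequence starts: stop passing prev_state and switch to
            # the init_state belonging to the new seq_idx
            state = initial_states[seq_idx[t]].clone()
        y, state = step(x[t], state)
        outputs.append(y)
    return torch.stack(outputs)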
@ -141,6 +143,8 @@ def _mamba_chunk_scan_combined_fwd(x,
D=D,
z=z,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
initial_states=initial_states,
)
if cu_seqlens is None:
@ -170,6 +174,8 @@ def mamba_chunk_scan_combined(x,
dt_bias=None,
initial_states=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
@ -210,6 +216,8 @@ def mamba_chunk_scan_combined(x,
dt_bias=dt_bias,
initial_states=initial_states,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
cu_seqlens=cu_seqlens,
dt_softplus=dt_softplus,
dt_limit=dt_limit)

View File

@ -150,8 +150,6 @@ def _state_passing_fwd(
# are used for continuous batching. In which case we
# require seq_idx to be provided
assert seq_idx is not None, ""
assert initial_states.shape == (seq_idx.max().item() + 1, nheads,
dim)
else:
# - this is the regular batching case, where initial
# states are used are for each example of the batch.

View File

@ -18,6 +18,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization import QuantizationConfig
@ -94,7 +96,6 @@ class BambaMixerDecoderLayer(nn.Module):
head_dim=config.mamba_d_head,
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act,
chunk_size=config.mamba_chunk_size,
quant_config=quant_config)
self.feed_forward = BambaMLP(config, quant_config=quant_config)
@ -108,7 +109,7 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor] = None,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
if residual is None:
@ -119,7 +120,7 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states, residual)
hidden_states = self.mamba(hidden_states, mamba_cache_params,
sequence_idx)
mamba2_metadata)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual)
@ -259,7 +260,7 @@ class BambaModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
config: BambaConfig = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@ -309,20 +310,13 @@ class BambaModel(nn.Module):
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# pass a sequence index tensor, that is required for
# proper continuous batching computation including
# chunked prefill
seq_idx = None
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.num_prefills > 0:
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
for i, (srt, end) in enumerate(
zip(
attn_metadata.query_start_loc,
attn_metadata.query_start_loc[1:],
)):
seq_idx[srt:end] = i
seq_idx.unsqueeze_(0)
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
input_ids=input_ids,
attn_metadata=attn_metadata,
)
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
@ -352,7 +346,7 @@ class BambaModel(nn.Module):
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=layer_mamba_cache_params,
sequence_idx=seq_idx,
mamba2_metadata=mamba2_metadata,
)
if not get_pp_group().is_last_rank:
@ -555,4 +549,4 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
return loader.load_weights(weights)

View File

@ -13,6 +13,8 @@ from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization.base_config import (
@ -57,7 +59,6 @@ class Mamba2DecoderLayer(nn.Module):
head_dim=config.head_dim,
rms_norm_eps=config.layer_norm_epsilon,
activation=config.hidden_act,
chunk_size=config.chunk_size,
quant_config=quant_config)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@ -67,7 +68,7 @@ class Mamba2DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor],
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
if residual is None:
@ -77,7 +78,7 @@ class Mamba2DecoderLayer(nn.Module):
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states, mamba_cache_params,
sequence_idx)
mamba2_metadata)
return hidden_states, residual
@ -138,20 +139,13 @@ class Mamba2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
# pass a sequence index tensor, that is required for
# proper continuous batching computation including
# chunked prefill
seq_idx = None
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
if attn_metadata.num_prefills > 0:
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
for i, (srt, end) in enumerate(
zip(
attn_metadata.query_start_loc,
attn_metadata.query_start_loc[1:],
)):
seq_idx[srt:end] = i
seq_idx.unsqueeze_(0)
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
input_ids=input_ids,
attn_metadata=attn_metadata,
)
for i in range(len(self.layers)):
layer = self.layers[i]
@ -162,7 +156,7 @@ class Mamba2Model(nn.Module):
residual=residual,
mamba_cache_params=mamba_cache_params.at_layer_idx(
i - self.start_layer),
sequence_idx=seq_idx)
mamba2_metadata=mamba2_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({

View File

@ -25,6 +25,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization import QuantizationConfig
@ -495,7 +497,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
head_dim=intermediate_size // config.n_mamba_heads,
rms_norm_eps=config.rms_norm_eps,
activation="silu",
chunk_size=config.chunk_size,
quant_config=quant_config,
)
@ -507,7 +508,7 @@ class Zamba2MambaDecoderLayer(nn.Module):
self,
hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor] = None,
mamba2_metadata: Mamba2Metadata,
transformer_hidden_states: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
original_hidden_states: Optional[torch.Tensor] = None,
@ -547,7 +548,7 @@ class Zamba2MambaDecoderLayer(nn.Module):
hidden_states = self.mamba(
hidden_states,
mamba_cache_params=mamba_cache_params,
sequence_idx=sequence_idx,
mamba2_metadata=mamba2_metadata,
)
# residual connection after mamba
@ -594,8 +595,8 @@ class Zamba2HybridLayer(nn.Module):
hidden_states: torch.Tensor,
original_hidden_states: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None,
sequence_idx: Optional[torch.Tensor] = None,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
) -> torch.Tensor:
"""Forward pass through the hybrid layer.
@ -634,7 +635,7 @@ class Zamba2HybridLayer(nn.Module):
hidden_states,
transformer_hidden_states=transformer_hidden_states,
mamba_cache_params=mamba_cache_params,
sequence_idx=sequence_idx,
mamba2_metadata=mamba2_metadata,
)
return layer_outputs
@ -747,20 +748,13 @@ class Zamba2Model(nn.Module):
inputs_embeds = self.get_input_embeddings(input_ids)
hidden_states = inputs_embeds
# pass a sequence index tensor, that is required for
# proper continuous batching computation including
# chunked prefill
seq_idx = None
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.num_prefills > 0:
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
for i, (srt, end) in enumerate(
zip(
attn_metadata.query_start_loc,
attn_metadata.query_start_loc[1:],
)):
seq_idx[srt:end] = i
seq_idx.unsqueeze_(0)
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
input_ids=input_ids,
attn_metadata=attn_metadata,
)
# Process through layers
original_hidden_states = torch.clone(hidden_states)
@ -770,7 +764,7 @@ class Zamba2Model(nn.Module):
original_hidden_states=original_hidden_states,
positions=positions,
mamba_cache_params=mamba_cache_params.at_layer_idx(layer_idx),
sequence_idx=seq_idx,
mamba2_metadata=mamba2_metadata,
)
hidden_states = layer_outputs