[Model] Reduce redundant computations in mamba2 blocks for Bamba-9B (#15423)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
parent 5fbab20e02
commit daefed052c

vllm/model_executor/layers/mamba/mamba2_metadata.py (new file, +109 lines)
@@ -0,0 +1,109 @@
+# SPDX-License-Identifier: Apache-2.0
+import math
+from dataclasses import dataclass
+
+import torch
+
+from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.attention.backends.placeholder_attn import (
+    PlaceholderAttentionMetadata)
+from vllm.attention.backends.xformers import XFormersMetadata
+
+
+@dataclass
+class Mamba2Metadata:
+    has_prefill: bool
+
+    has_initial_states: torch.Tensor
+    prep_initial_states: bool
+
+    chunk_size: int
+    seq_idx: torch.Tensor
+    chunk_indices: torch.Tensor
+    chunk_offsets: torch.Tensor
+
+
+def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
+
+    # convert seq_idx to chunk indices and offsets
+    # - derive the cu_seqlens
+    _, cu_seqlens = torch.where(seq_idx.diff())
+    cu_seqlens += 1
+
+    # outputs will have length expansion of chunks that do not divide
+    # chunk_size
+    N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
+                                                     > 0).sum()
+    chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
+    chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
+
+    cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
+    p = 0  # num of insertions
+    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
+
+        # if does not divide chunk_size, then there is one chunk insertion
+        p += (s % chunk_size > 0)
+
+        # get the dimensions
+        # - the + 1 for _e is to shift the boundary by one chunk
+        # - this shifting is not needed if chunk_size divides e
+        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
+                                                             > 0)
+
+        # adjust inidces and offsets
+        chunk_indices[_s:_e] -= p
+        chunk_offsets[_s] = s % chunk_size
+
+    return chunk_indices, chunk_offsets
+
+
+def prepare_mamba2_metadata(
+    chunk_size: int,
+    input_ids: torch.Tensor,
+    attn_metadata: AttentionMetadata,
+) -> Mamba2Metadata:
+
+    # Need flags to indicate if there are initial states
+    # currently we really only support the FlashAttention backend
+    has_initial_states = None
+    prep_initial_states = False
+    if (isinstance(attn_metadata, (FlashAttentionMetadata, XFormersMetadata,
+                                   PlaceholderAttentionMetadata))
+            and attn_metadata.context_lens_tensor is not None):
+        has_initial_states = attn_metadata.context_lens_tensor > 0
+        # precompute flag to avoid device syncs later in mamba2 forwards
+        prep_initial_states = torch.any(has_initial_states).item()
+
+    has_prefill = attn_metadata.num_prefills > 0
+
+    seq_idx = None
+    chunk_indices, chunk_offsets = None, None
+    if has_prefill:
+        seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
+        for i, (srt, end) in enumerate(
+                zip(
+                    attn_metadata.query_start_loc,
+                    attn_metadata.query_start_loc[1:],
+                )):
+            seq_idx[srt:end] = i
+        seq_idx.unsqueeze_(0)
+
+        # compute metadata for chunked prefill.
+        # actually this is only needed if there are initial states,
+        # but this is determinable only from attention metadata yet
+        # unavailable from the top-level model forward. Rather than
+        # complicating things to extract said metadata, we simply just
+        # compute them once at the top level model forward and reuse
+        # them in mamba layers. If not needed, they will be ignored
+        # inside mamba kernels.
+        chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
+            seq_idx, chunk_size)
+
+    return Mamba2Metadata(has_prefill=has_prefill,
+                          has_initial_states=has_initial_states,
+                          prep_initial_states=prep_initial_states,
+                          chunk_size=chunk_size,
+                          seq_idx=seq_idx,
+                          chunk_indices=chunk_indices,
+                          chunk_offsets=chunk_offsets)
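
Note (not part of the diff): a small sketch of what `_seq_idx_to_chunk_indices_offsets` above produces, assuming a vLLM installation that includes this change and can import its attention backends; the toy tensor and the printed values are illustrative only. Two prefill sequences of lengths 6 and 3 are packed into one batch with chunk_size=4, so the second sequence starts in the middle of physical chunk 1 and an extra logical chunk is inserted for it:

import torch

from vllm.model_executor.layers.mamba.mamba2_metadata import (
    _seq_idx_to_chunk_indices_offsets)

# seq_idx labels each token of the packed batch with its sequence id,
# shape (1, total_tokens), exactly as prepare_mamba2_metadata builds it.
seq_idx = torch.tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1]], dtype=torch.int32)

chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
    seq_idx, chunk_size=4)

# Logical chunk 2 revisits physical chunk 1 starting at offset 2, where
# sequence 1 begins; the other logical chunks start at offset 0.
print(chunk_indices.tolist())  # [0, 1, 1, 2]
print(chunk_offsets.tolist())  # [0, 0, 2, 0]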
@@ -6,10 +6,6 @@ import torch
 from torch import nn
 
 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.attention.backends.placeholder_attn import (
-    PlaceholderAttentionMetadata)
-from vllm.attention.backends.xformers import XFormersMetadata
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather,
@@ -18,6 +14,7 @@ from vllm.forward_context import get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@@ -221,7 +218,6 @@ class MambaMixer2(CustomOp):
                  head_dim: int = 64,
                  rms_norm_eps: float = 1e-5,
                  activation="silu",
-                 chunk_size: int = 256,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
 
@@ -257,7 +253,6 @@ class MambaMixer2(CustomOp):
         self.ssm_state_size = ssm_state_size
         self.activation = activation
 
-        self.chunk_size = chunk_size
         self.intermediate_size = intermediate_size
         self.head_dim = head_dim
         self.num_heads = num_heads
@@ -388,25 +383,17 @@ class MambaMixer2(CustomOp):
         self,
         hidden_states: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor] = None,
+        mamba2_metadata: Mamba2Metadata,
     ):
+        # mamba2_metadata contains metadata necessary for the mamba2 triton
+        # kernels to operate in continuous batching and in chunked prefill
+        # modes; they are computed at top-level model forward since they
+        # are the same and reused for all mamba layers in the same iteration
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
 
         seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size
 
-        # detect if there are prefills
-        has_prefill = attn_metadata.num_prefills > 0
-
-        # - also need flags to indicate if there are initial states
-        # - currently we really only support the FlashAttention backend
-        has_initial_states = None
-        if (isinstance(attn_metadata,
-                       (FlashAttentionMetadata, XFormersMetadata,
-                        PlaceholderAttentionMetadata))
-                and attn_metadata.context_lens_tensor is not None):
-            has_initial_states = attn_metadata.context_lens_tensor > 0
-
         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
         gate, hidden_states_B_C, dt = torch.split(
@@ -423,7 +410,7 @@ class MambaMixer2(CustomOp):
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
 
-        if has_prefill:
+        if mamba2_metadata.has_prefill:
             # |---------- N-1 iteration --------|
             # |---------------- N iteration ---------------------|
             # |- tokenA -|......................|-- newTokens ---|
@@ -439,7 +426,7 @@ class MambaMixer2(CustomOp):
                 self.conv1d.bias,
                 activation=self.activation,
                 conv_states=mamba_cache_params.conv_state,
-                has_initial_state=has_initial_states,
+                has_initial_state=mamba2_metadata.has_initial_states,
                 cache_indices=mamba_cache_params.state_indices_tensor,
                 query_start_loc=attn_metadata.query_start_loc).transpose(
                     0, 1)[:seq_len]
@@ -467,16 +454,15 @@ class MambaMixer2(CustomOp):
         )
 
         # 3. State Space Model sequence transformation
-        if has_prefill:
-
+        if mamba2_metadata.has_prefill:
             initial_states = None
-            if has_initial_states is not None and torch.any(
-                    has_initial_states):
-                zero_init_indices = mamba_cache_params.state_indices_tensor[
-                    ~has_initial_states]
-                mamba_cache_params.ssm_state[zero_init_indices] = 0
-                initial_states = mamba_cache_params.ssm_state[
-                    mamba_cache_params.state_indices_tensor]
+            if (mamba2_metadata.has_initial_states is not None
+                    and mamba2_metadata.prep_initial_states):
+                # making a copy of the states
+                initial_states = torch.where(
+                    mamba2_metadata.has_initial_states[:, None, None, None],
+                    mamba_cache_params.ssm_state[
+                        mamba_cache_params.state_indices_tensor], 0)
 
             scan_output, varlen_state = mamba_chunk_scan_combined(
                 hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
@@ -485,11 +471,13 @@ class MambaMixer2(CustomOp):
                 self.A,
                 B.view(1, seq_len, self.n_groups // self.tp_size, -1),
                 C.view(1, seq_len, self.n_groups // self.tp_size, -1),
-                chunk_size=self.chunk_size,
+                chunk_size=mamba2_metadata.chunk_size,
                 D=self.D,
                 z=None,
                 dt_bias=self.dt_bias,
-                seq_idx=sequence_idx,
+                seq_idx=mamba2_metadata.seq_idx,
+                chunk_indices=mamba2_metadata.chunk_indices,
+                chunk_offsets=mamba2_metadata.chunk_offsets,
                 cu_seqlens=attn_metadata.query_start_loc,
                 initial_states=initial_states,
                 return_varlen_states=True,
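
Note (not part of the diff): the `prep_initial_states` flag consumed above carries the device-sync saving described in the new module's comments. The old per-layer `torch.any(has_initial_states)` in Python control flow forced a GPU-to-CPU sync in every mamba2 layer; now that sync is paid once inside `prepare_mamba2_metadata` and each layer branches on a plain Python bool. A minimal, CPU-only sketch of the idea, with a made-up tensor and layer count:

import torch

# Stands in for attn_metadata.context_lens_tensor > 0 (a GPU tensor in real runs).
has_initial_states = torch.tensor([True, False, True])

# Old pattern: every layer evaluated torch.any(...) in an `if`, one sync per layer.
# New pattern: sync once, reuse the resulting bool in every layer.
prep_initial_states = torch.any(has_initial_states).item()

for _layer in range(3):  # stand-in for the model's mamba2 layers
    if prep_initial_states:  # plain bool, no per-layer device sync
        pass  # prepare initial states for the chunked scan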
@@ -5,8 +5,6 @@
 
 # ruff: noqa: E501,SIM102
 
-import math
-
 import torch
 import triton
 import triton.language as tl
@@ -442,40 +440,6 @@ def _chunk_scan_fwd_kernel(
              (offs_out_n[None, :] < hdim))
 
 
-def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
-
-    # convert seq_idx to chunk indices and offsets
-    # - derive the cu_seqlens
-    _, cu_seqlens = torch.where(seq_idx.diff())
-    cu_seqlens += 1
-
-    # outputs will have length expansion of chunks that do not divide
-    # chunk_size
-    N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
-                                                     > 0).sum()
-    chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
-    chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
-
-    cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
-    p = 0  # num of insertions
-    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
-
-        # if does not divide chunk_size, then there is one chunk insertion
-        p += (s % chunk_size > 0)
-
-        # get the dimensions
-        # - the + 1 for _e is to shift the boundary by one chunk
-        # - this shifting is not needed if chunk_size divides e
-        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
-                                                             > 0)
-
-        # adjust inidces and offsets
-        chunk_indices[_s:_e] -= p
-        chunk_offsets[_s] = s % chunk_size
-
-    return chunk_indices, chunk_offsets
-
-
 def _chunk_scan_fwd(
     cb,
     x,
@@ -486,6 +450,8 @@ def _chunk_scan_fwd(
     D=None,
     z=None,
     seq_idx=None,
+    chunk_indices=None,
+    chunk_offsets=None,
     initial_states=None,
 ):
     batch, seqlen, nheads, headdim = x.shape
@@ -502,7 +468,6 @@ def _chunk_scan_fwd(
     assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
     assert states.shape == (batch, nchunks, nheads, headdim, dstate)
 
-    chunk_indices, chunk_offsets = None, None
    if seq_idx is not None:
         assert seq_idx.shape == (batch, seqlen)
 
@@ -510,15 +475,19 @@ def _chunk_scan_fwd(
             # with initial states, we need to take care of how
             # seq_idx crosses the boundaries
             assert batch == 1, "chunk scan only supports initial states with batch 1"
-            assert initial_states.shape == (seq_idx[0].max() + 1, nheads,
-                                            headdim, dstate)
 
             if initial_states.shape[0] == 1:
                 # no in this case no point to use initial states
                 initial_states = None
             else:
-                chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
-                    seq_idx, chunk_size)
+                assert chunk_indices is not None and chunk_offsets is not None, \
+                    (
+                        "chunk_indices and chunk_offsets should have been set"
+                    )
+        else:
+            chunk_indices, chunk_offsets = None, None
+    else:
+        chunk_indices, chunk_offsets = None, None
 
     # Allocates output.
     out = torch.empty(batch,
@@ -30,6 +30,8 @@ def _mamba_chunk_scan_combined_fwd(x,
                                    dt_bias=None,
                                    initial_states=None,
                                    seq_idx=None,
+                                   chunk_indices=None,
+                                   chunk_offsets=None,
                                    cu_seqlens=None,
                                    dt_softplus=False,
                                    dt_limit=(0.0, float("inf"))):
@@ -96,7 +98,7 @@ def _mamba_chunk_scan_combined_fwd(x,
     # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
     # (middle term of factorization of off-diag blocks; A terms)
     # - for handling chunked prefill, this requires i) initial_states
-    #   ii) seq_idx and iii) has_cu_seqlens to be all specified.
+    #   ii) seq_idx and iii) is_cont_batched to be all specified.
     # - When a new seq_idx is detected, we will stop passing the prev_state
     #   and switch accordingly to the init_state corresponding to the new seq_idx.
     # - this will ensure that states will be updated with the rightmost flushed seq_idx
@@ -141,6 +143,8 @@ def _mamba_chunk_scan_combined_fwd(x,
         D=D,
         z=z,
         seq_idx=seq_idx,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
         initial_states=initial_states,
     )
     if cu_seqlens is None:
@@ -170,6 +174,8 @@ def mamba_chunk_scan_combined(x,
                              dt_bias=None,
                              initial_states=None,
                              seq_idx=None,
+                             chunk_indices=None,
+                             chunk_offsets=None,
                              cu_seqlens=None,
                              dt_softplus=False,
                              dt_limit=(0.0, float("inf")),
@@ -210,6 +216,8 @@ def mamba_chunk_scan_combined(x,
         dt_bias=dt_bias,
         initial_states=initial_states,
         seq_idx=seq_idx,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
         cu_seqlens=cu_seqlens,
         dt_softplus=dt_softplus,
         dt_limit=dt_limit)
@@ -150,8 +150,6 @@ def _state_passing_fwd(
             # are used for continuous batching. In which case we
             # require seq_idx to be provided
             assert seq_idx is not None, ""
-            assert initial_states.shape == (seq_idx.max().item() + 1, nheads,
-                                            dim)
     else:
         # - this is the regular batching case, where initial
         # states are used are for each example of the batch.
@@ -18,6 +18,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba2_metadata import (
+    Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     MambaMixer2, extra_groups_for_head_shards)
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -94,7 +96,6 @@ class BambaMixerDecoderLayer(nn.Module):
                                  head_dim=config.mamba_d_head,
                                  rms_norm_eps=config.rms_norm_eps,
                                  activation=config.hidden_act,
-                                 chunk_size=config.mamba_chunk_size,
                                  quant_config=quant_config)
 
         self.feed_forward = BambaMLP(config, quant_config=quant_config)
@@ -108,7 +109,7 @@ class BambaMixerDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor] = None,
+        mamba2_metadata: Mamba2Metadata,
         **kwargs,
     ):
         if residual is None:
@@ -119,7 +120,7 @@ class BambaMixerDecoderLayer(nn.Module):
                 hidden_states, residual)
 
         hidden_states = self.mamba(hidden_states, mamba_cache_params,
-                                   sequence_idx)
+                                   mamba2_metadata)
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
             hidden_states, residual)
@@ -259,7 +260,7 @@ class BambaModel(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
 
-        config = vllm_config.model_config.hf_config
+        config: BambaConfig = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
@@ -309,20 +310,13 @@ class BambaModel(nn.Module):
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
-        # pass a sequence index tensor, that is required for
-        # proper continuous batching computation including
-        # chunked prefill
-        seq_idx = None
         attn_metadata = get_forward_context().attn_metadata
-        if attn_metadata.num_prefills > 0:
-            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
-            for i, (srt, end) in enumerate(
-                    zip(
-                        attn_metadata.query_start_loc,
-                        attn_metadata.query_start_loc[1:],
-                    )):
-                seq_idx[srt:end] = i
-            seq_idx.unsqueeze_(0)
+
+        mamba2_metadata = prepare_mamba2_metadata(
+            chunk_size=self.config.mamba_chunk_size,
+            input_ids=input_ids,
+            attn_metadata=attn_metadata,
+        )
 
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
@@ -352,7 +346,7 @@ class BambaModel(nn.Module):
                 hidden_states=hidden_states,
                 residual=residual,
                 mamba_cache_params=layer_mamba_cache_params,
-                sequence_idx=seq_idx,
+                mamba2_metadata=mamba2_metadata,
             )
 
         if not get_pp_group().is_last_rank:
@@ -555,4 +549,4 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights)
@@ -13,6 +13,8 @@ from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba2_metadata import (
+    Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     MambaMixer2, extra_groups_for_head_shards)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -57,7 +59,6 @@ class Mamba2DecoderLayer(nn.Module):
                                  head_dim=config.head_dim,
                                  rms_norm_eps=config.layer_norm_epsilon,
                                  activation=config.hidden_act,
-                                 chunk_size=config.chunk_size,
                                  quant_config=quant_config)
 
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -67,7 +68,7 @@ class Mamba2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor],
+        mamba2_metadata: Mamba2Metadata,
         **kwargs,
     ):
         if residual is None:
@@ -77,7 +78,7 @@ class Mamba2DecoderLayer(nn.Module):
             hidden_states, residual = self.norm(hidden_states, residual)
 
         hidden_states = self.mixer(hidden_states, mamba_cache_params,
-                                   sequence_idx)
+                                   mamba2_metadata)
         return hidden_states, residual
 
 
@@ -138,20 +139,13 @@ class Mamba2Model(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        # pass a sequence index tensor, that is required for
-        # proper continuous batching computation including
-        # chunked prefill
-        seq_idx = None
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
-        if attn_metadata.num_prefills > 0:
-            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
-            for i, (srt, end) in enumerate(
-                    zip(
-                        attn_metadata.query_start_loc,
-                        attn_metadata.query_start_loc[1:],
-                    )):
-                seq_idx[srt:end] = i
-            seq_idx.unsqueeze_(0)
+
+        mamba2_metadata = prepare_mamba2_metadata(
+            chunk_size=self.config.chunk_size,
+            input_ids=input_ids,
+            attn_metadata=attn_metadata,
+        )
 
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -162,7 +156,7 @@ class Mamba2Model(nn.Module):
                 residual=residual,
                 mamba_cache_params=mamba_cache_params.at_layer_idx(
                     i - self.start_layer),
-                sequence_idx=seq_idx)
+                mamba2_metadata=mamba2_metadata)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
@@ -25,6 +25,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba2_metadata import (
+    Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     MambaMixer2, extra_groups_for_head_shards)
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -495,7 +497,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
             head_dim=intermediate_size // config.n_mamba_heads,
             rms_norm_eps=config.rms_norm_eps,
             activation="silu",
-            chunk_size=config.chunk_size,
             quant_config=quant_config,
         )
 
@@ -507,7 +508,7 @@ class Zamba2MambaDecoderLayer(nn.Module):
         self,
         hidden_states: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor] = None,
+        mamba2_metadata: Mamba2Metadata,
         transformer_hidden_states: Optional[torch.Tensor] = None,
         positions: Optional[torch.Tensor] = None,
         original_hidden_states: Optional[torch.Tensor] = None,
@@ -547,7 +548,7 @@ class Zamba2MambaDecoderLayer(nn.Module):
         hidden_states = self.mamba(
             hidden_states,
             mamba_cache_params=mamba_cache_params,
-            sequence_idx=sequence_idx,
+            mamba2_metadata=mamba2_metadata,
         )
 
         # residual connection after mamba
@@ -594,8 +595,8 @@ class Zamba2HybridLayer(nn.Module):
         hidden_states: torch.Tensor,
         original_hidden_states: torch.Tensor,
         positions: torch.Tensor,
-        mamba_cache_params: Optional[MambaCacheParams] = None,
-        sequence_idx: Optional[torch.Tensor] = None,
+        mamba_cache_params: MambaCacheParams,
+        mamba2_metadata: Mamba2Metadata,
     ) -> torch.Tensor:
         """Forward pass through the hybrid layer.
 
@@ -634,7 +635,7 @@ class Zamba2HybridLayer(nn.Module):
             hidden_states,
             transformer_hidden_states=transformer_hidden_states,
             mamba_cache_params=mamba_cache_params,
-            sequence_idx=sequence_idx,
+            mamba2_metadata=mamba2_metadata,
         )
 
         return layer_outputs
@@ -747,20 +748,13 @@ class Zamba2Model(nn.Module):
             inputs_embeds = self.get_input_embeddings(input_ids)
         hidden_states = inputs_embeds
 
-        # pass a sequence index tensor, that is required for
-        # proper continuous batching computation including
-        # chunked prefill
-        seq_idx = None
         attn_metadata = get_forward_context().attn_metadata
-        if attn_metadata.num_prefills > 0:
-            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
-            for i, (srt, end) in enumerate(
-                    zip(
-                        attn_metadata.query_start_loc,
-                        attn_metadata.query_start_loc[1:],
-                    )):
-                seq_idx[srt:end] = i
-            seq_idx.unsqueeze_(0)
+
+        mamba2_metadata = prepare_mamba2_metadata(
+            chunk_size=self.config.chunk_size,
+            input_ids=input_ids,
+            attn_metadata=attn_metadata,
+        )
 
         # Process through layers
         original_hidden_states = torch.clone(hidden_states)
@@ -770,7 +764,7 @@ class Zamba2Model(nn.Module):
                 original_hidden_states=original_hidden_states,
                 positions=positions,
                 mamba_cache_params=mamba_cache_params.at_layer_idx(layer_idx),
-                sequence_idx=seq_idx,
+                mamba2_metadata=mamba2_metadata,
             )
             hidden_states = layer_outputs
 
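
Note (not part of the diff): the Bamba, Mamba2 and Zamba2 model forwards above now all follow the same pattern, so the prefill flags, seq_idx and chunk indices/offsets are derived exactly once per model forward. A minimal sketch of that common pattern, assuming a vLLM installation with this change; `model_forward`, `self.layers`, `self.config.chunk_size` and the simplified layer signature are stand-ins for the concrete model code shown in the hunks above:

from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.mamba.mamba2_metadata import (
    prepare_mamba2_metadata)


def model_forward(self, input_ids, hidden_states, mamba_cache_params):
    attn_metadata = get_forward_context().attn_metadata

    # Computed once per forward: prefill flag, initial-state flags, seq_idx,
    # and the chunk indices/offsets for chunked prefill.
    mamba2_metadata = prepare_mamba2_metadata(
        chunk_size=self.config.chunk_size,
        input_ids=input_ids,
        attn_metadata=attn_metadata,
    )

    # Reused by every mamba2 layer in this iteration instead of each layer
    # re-deriving it from attn_metadata.
    for i, layer in enumerate(self.layers):
        hidden_states = layer(
            hidden_states,
            mamba_cache_params=mamba_cache_params.at_layer_idx(i),
            mamba2_metadata=mamba2_metadata,
        )
    return hidden_states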