[Model] Mamba2 Prefill Performance Tweaks: Fixing a Flurry of Unnecessary Memory Copies (#14778)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Chih-Chieh Yang 2025-03-14 16:36:18 -04:00 committed by GitHub
parent 270a5da495
commit fe66b34728


@@ -466,10 +466,17 @@ class MambaMixer2(CustomOp):
         if has_prefill:
             initial_states = None
-            if has_initial_states is not None and any(has_initial_states):
-                for idx in mamba_cache_params.state_indices_tensor[
-                        ~has_initial_states]:
-                    mamba_cache_params.ssm_state[idx].zero_()
+            if has_initial_states is not None and torch.any(
+                    has_initial_states):
+                # vectorized ssm_state zero init
+                batched_zero_init_func = torch.vmap(
+                    lambda idx: mamba_cache_params.ssm_state[idx].zero_())
+                batched_zero_init_func(
+                    mamba_cache_params.
+                    state_indices_tensor[~has_initial_states].unsqueeze(
+                        dim=-1), )
                 initial_states = mamba_cache_params.ssm_state[
                     mamba_cache_params.state_indices_tensor]
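The intent of this hunk is easier to see outside the diff: the old code launched one tiny .zero_() kernel per cache slot, while the new code clears every selected slot in a single batched call. Below is a minimal, self-contained sketch of the same zero-init step with hypothetical cache shapes; for brevity it uses a plain advanced-indexing write rather than the torch.vmap wrapper from the commit, which collapses the loop in the same way:

    import torch

    # Hypothetical shapes: ssm_state is (num_cache_slots, nheads, headdim, dstate).
    ssm_state = torch.randn(8, 4, 16, 32)
    state_indices_tensor = torch.tensor([5, 2, 7])         # cache slot per sequence
    has_initial_states = torch.tensor([True, False, False])

    # Old pattern: one small .zero_() launch per slot lacking an initial state.
    for idx in state_indices_tensor[~has_initial_states]:
        ssm_state[idx].zero_()

    # Vectorized equivalent: a single indexed write clears all selected slots.
    ssm_state[state_indices_tensor[~has_initial_states]] = 0.0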
@@ -493,10 +500,17 @@ class MambaMixer2(CustomOp):
                 dt_limit=(0.0, float("inf")),
             )
             # update ssm states
             # - varlen state is a (batch, nheads, headdim, dstate) tensor
-            for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
-                mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])
+            # vectorized ssm state update using vmap
+            # the 1d state_indices_tensor needs to be unsqueezed to avoid vmap
+            # limitation which doesn't allow use of `item()`
+            # Note: the lambda capture can happen where ssm_state is initialized
+            # instead of here
+            batched_copy = torch.vmap(
+                lambda idx, source_state: mamba_cache_params.ssm_state[
+                    idx].copy_(source_state))
+            batched_copy(
+                mamba_cache_params.state_indices_tensor.unsqueeze(dim=-1),
+                varlen_state)
             # - reshape
             hidden_states = scan_output.view(seq_len, -1)
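The second hunk replaces the per-sequence copy loop the same way. As a sketch of the effect, again with hypothetical shapes and using index_copy_ in place of the commit's torch.vmap wrapper (the vmap form exists so the 1-d index tensor can be consumed without `item()`, per the code comments above):

    import torch

    # Hypothetical shapes matching the comment in the diff:
    # varlen_state is a (batch, nheads, headdim, dstate) tensor.
    ssm_state = torch.zeros(8, 4, 16, 32)
    state_indices_tensor = torch.tensor([5, 2, 7])
    varlen_state = torch.randn(3, 4, 16, 32)

    # Old pattern: one copy launch per sequence in the batch.
    for i, idx in enumerate(state_indices_tensor):
        ssm_state[idx].copy_(varlen_state[i])

    # Vectorized alternative: index_copy_ scatters the whole batch along the
    # slot dimension in a single call.
    ssm_state.index_copy_(0, state_indices_tensor, varlen_state)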