[Model] Mamba2 Prefill Performance Tweaks: Fixing Flurry of Unnecessary Memory Copies (#14778)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Author: Chih-Chieh Yang
Date: 2025-03-14 16:36:18 -04:00 (committed by GitHub)
Commit: fe66b34728 (parent 270a5da495)

@@ -466,10 +466,17 @@ class MambaMixer2(CustomOp):
         if has_prefill:
             initial_states = None
-            if has_initial_states is not None and any(has_initial_states):
-                for idx in mamba_cache_params.state_indices_tensor[
-                        ~has_initial_states]:
-                    mamba_cache_params.ssm_state[idx].zero_()
+            if has_initial_states is not None and torch.any(
+                    has_initial_states):
+                # vectorized ssm_state zero init
+                batched_zero_init_func = torch.vmap(
+                    lambda idx: mamba_cache_params.ssm_state[idx].zero_())
+                batched_zero_init_func(
+                    mamba_cache_params.
+                    state_indices_tensor[~has_initial_states].unsqueeze(
+                        dim=-1), )
                 initial_states = mamba_cache_params.ssm_state[
                     mamba_cache_params.state_indices_tensor]
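
The hunk above replaces a Python-level loop of per-slot .zero_() calls with a single batched launch. Below is a minimal standalone sketch of the same zero-init-then-gather idea; it deliberately uses plain advanced indexing (which lowers to a single index_put_ kernel) rather than the torch.vmap wrapper from the patch, and all shapes, names, and tensor contents are illustrative, not taken from vLLM:

    import torch

    # illustrative shapes; the real ones come from the model config
    num_slots, nheads, headdim, dstate = 8, 4, 64, 16
    ssm_state = torch.randn(num_slots, nheads, headdim, dstate)

    state_indices = torch.tensor([5, 2, 7])          # cache slot per sequence
    has_initial_states = torch.tensor([True, False, True])

    # zero every slot that has no initial state in one indexed write,
    # instead of looping over slots and calling .zero_() on each
    ssm_state[state_indices[~has_initial_states]] = 0.0

    # gather the (possibly zeroed) initial states for the whole batch
    initial_states = ssm_state[state_indices]
    assert torch.all(initial_states[1] == 0)  # sequence 1 had no initial state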
@@ -493,10 +500,17 @@ class MambaMixer2(CustomOp):
                 dt_limit=(0.0, float("inf")),
             )
-            # update ssm states
-            # - varlen state is a (batch, nheads, headdim, dstate) tensor
-            for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
-                mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])
+            # vectorized ssm state update using vmap
+            # the 1d state_indices_tensor needs to be unsqueezed to avoid a
+            # vmap limitation that disallows use of `item()`
+            # Note: the lambda capture could happen where ssm_state is
+            # initialized instead of here
+            batched_copy = torch.vmap(
+                lambda idx, source_state: mamba_cache_params.ssm_state[
+                    idx].copy_(source_state))
+            batched_copy(
+                mamba_cache_params.state_indices_tensor.unsqueeze(dim=-1),
+                varlen_state)
             # - reshape
             hidden_states = scan_output.view(seq_len, -1)
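
The second hunk applies the same vectorization to the copy-back of the final prefill states, where varlen_state holds one (nheads, headdim, dstate) state per sequence. Sketched below with advanced-index assignment instead of torch.vmap (again with made-up shapes), the whole scatter becomes one indexed write:

    import torch

    # illustrative shapes only
    num_slots, batch, nheads, headdim, dstate = 8, 3, 4, 64, 16
    ssm_state = torch.zeros(num_slots, nheads, headdim, dstate)
    state_indices = torch.tensor([5, 2, 7])  # destination slot per sequence

    # varlen_state: final SSM state per prefill sequence, shaped
    # (batch, nheads, headdim, dstate) as the removed comment noted
    varlen_state = torch.randn(batch, nheads, headdim, dstate)

    # one batched indexed write replaces the per-sequence .copy_() loop
    ssm_state[state_indices] = varlen_state
    assert torch.equal(ssm_state[5], varlen_state[0])

As for the unsqueeze(dim=-1) in the patch: indexing with a 0-d tensor takes PyTorch's integer-index path, which calls .item() internally, and .item() cannot run on a vmap-batched tensor. Mapping over an (n, 1) index tensor instead hands each call a shape-(1,) index that goes through the advanced-indexing path, which vmap supports.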