[Model] Mamba2 Prefill Performance Tweaks: Fixing Flurry of Unnecessary Memory Copies (#14778)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
commit fe66b34728
parent 270a5da495
@@ -466,10 +466,17 @@ class MambaMixer2(CustomOp):
         if has_prefill:
 
             initial_states = None
-            if has_initial_states is not None and any(has_initial_states):
-                for idx in mamba_cache_params.state_indices_tensor[
-                        ~has_initial_states]:
-                    mamba_cache_params.ssm_state[idx].zero_()
+            if has_initial_states is not None and torch.any(
+                    has_initial_states):
+
+                # vectorized ssm_state zero init
+                batched_zero_init_func = torch.vmap(
+                    lambda idx: mamba_cache_params.ssm_state[idx].zero_())
+                batched_zero_init_func(
+                    mamba_cache_params.
+                    state_indices_tensor[~has_initial_states].unsqueeze(
+                        dim=-1), )
                 initial_states = mamba_cache_params.ssm_state[
                     mamba_cache_params.state_indices_tensor]
 
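What changed and why: each iteration of the removed loop issued a separate .zero_() on a single cache slot, i.e. a flurry of tiny device memsets, while the vmap form clears every selected slot in one batched call. As a sanity check, here is a minimal standalone sketch of the equivalence; the shapes and values are toy stand-ins (not vLLM's real cache layout), and it uses plain advanced-indexing assignment rather than the commit's torch.vmap wrapper, since both zero the same selected slots:

import torch

# Toy stand-ins: num_slots, nheads, headdim, dstate and the example
# index/mask values are illustrative, not vLLM's actual shapes.
num_slots, nheads, headdim, dstate = 8, 4, 16, 32
ssm_state = torch.randn(num_slots, nheads, headdim, dstate)
state_indices_tensor = torch.tensor([2, 5, 7])  # one cache slot per sequence
has_initial_states = torch.tensor([True, False, False])

# Old behaviour: one small .zero_() call per selected sequence.
reference = ssm_state.clone()
for idx in state_indices_tensor[~has_initial_states]:
    reference[idx].zero_()

# Batched equivalent: a single advanced-indexing write zeroes all
# selected slots at once.
ssm_state[state_indices_tensor[~has_initial_states]] = 0.0

assert torch.equal(ssm_state, reference)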
@@ -493,10 +500,17 @@ class MambaMixer2(CustomOp):
                 dt_limit=(0.0, float("inf")),
             )
 
             # update ssm states
             # - varlen state is a (batch, nheads, headdim, dstate) tensor
-            for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
-                mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])
+            # vectorized ssm state update using vmap
+            # the 1d state_indices_tensor needs to be unsqueezed to avoid vmap
+            # limitation which doesn't allow use of `item()`
+            # Note: the lambda capture can happen where ssm_state is initialized
+            # instead of here
+            batched_copy = torch.vmap(
+                lambda idx, source_state: mamba_cache_params.ssm_state[
+                    idx].copy_(source_state))
+            batched_copy(
+                mamba_cache_params.state_indices_tensor.unsqueeze(dim=-1),
+                varlen_state)
 
             # - reshape
             hidden_states = scan_output.view(seq_len, -1)
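The same pattern applies to writing the final per-sequence prefill states back into the cache: the old loop issued one small device-to-device copy_ per sequence, while the batched form issues a single write for the whole batch. Below is a toy sketch of the equivalence; the shapes are made up, and it uses a direct indexed assignment in place of the commit's torch.vmap(copy_) form. The write is well defined because each sequence owns a distinct cache slot, so the indices are unique; note that the unsqueeze(dim=-1) workaround described in the diff comment is only needed on the vmap path, not here:

import torch

# Toy stand-ins; per the diff's comment, varlen_state is a
# (batch, nheads, headdim, dstate) tensor of final prefill states.
num_slots, nheads, headdim, dstate = 8, 4, 16, 32
ssm_state = torch.zeros(num_slots, nheads, headdim, dstate)
state_indices_tensor = torch.tensor([2, 5, 7])
varlen_state = torch.randn(state_indices_tensor.numel(), nheads, headdim,
                           dstate)

# Old behaviour: one copy_ per sequence, each a separate small copy.
reference = ssm_state.clone()
for i, idx in enumerate(state_indices_tensor):
    reference[idx].copy_(varlen_state[i])

# Batched equivalent: one scatter-style write covers the whole batch.
ssm_state[state_indices_tensor] = varlen_state

assert torch.equal(ssm_state, reference)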