[Model] Mamba2 Prefill Performance Tweaks: Fixing Flurry of Unnecessary Memory Copies (#14778)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
This commit is contained in:
parent
270a5da495
commit
fe66b34728
@ -466,10 +466,17 @@ class MambaMixer2(CustomOp):
|
|||||||
if has_prefill:
|
if has_prefill:
|
||||||
|
|
||||||
initial_states = None
|
initial_states = None
|
||||||
if has_initial_states is not None and any(has_initial_states):
|
|
||||||
for idx in mamba_cache_params.state_indices_tensor[
|
if has_initial_states is not None and torch.any(
|
||||||
~has_initial_states]:
|
has_initial_states):
|
||||||
mamba_cache_params.ssm_state[idx].zero_()
|
|
||||||
|
# vectorized ssm_state zero init
|
||||||
|
batched_zero_init_func = torch.vmap(
|
||||||
|
lambda idx: mamba_cache_params.ssm_state[idx].zero_())
|
||||||
|
batched_zero_init_func(
|
||||||
|
mamba_cache_params.
|
||||||
|
state_indices_tensor[~has_initial_states].unsqueeze(
|
||||||
|
dim=-1), )
|
||||||
initial_states = mamba_cache_params.ssm_state[
|
initial_states = mamba_cache_params.ssm_state[
|
||||||
mamba_cache_params.state_indices_tensor]
|
mamba_cache_params.state_indices_tensor]
|
||||||
|
|
||||||
@ -493,10 +500,17 @@ class MambaMixer2(CustomOp):
|
|||||||
dt_limit=(0.0, float("inf")),
|
dt_limit=(0.0, float("inf")),
|
||||||
)
|
)
|
||||||
|
|
||||||
# update ssm states
|
# vectorized ssm state update using vmap
|
||||||
# - varlen state is a (batch, nheads, headdim, dstate) tensor
|
# the 1-D state_indices_tensor must be unsqueezed to work around a vmap
|
||||||
for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
|
# limitation that does not allow calling `item()`
|
||||||
mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])
|
# Note: the lambda capture can happen where ssm_state is initialized
|
||||||
|
# instead of here
|
||||||
|
batched_copy = torch.vmap(
|
||||||
|
lambda idx, source_state: mamba_cache_params.ssm_state[
|
||||||
|
idx].copy_(source_state))
|
||||||
|
batched_copy(
|
||||||
|
mamba_cache_params.state_indices_tensor.unsqueeze(dim=-1),
|
||||||
|
varlen_state)
|
||||||
|
|
||||||
# - reshape
|
# - reshape
|
||||||
hidden_states = scan_output.view(seq_len, -1)
|
hidden_states = scan_output.view(seq_len, -1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user