[Bugfix] Fix tests/kernels/test_mamba_ssm_ssd.py (#16623)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Tyler Michael Smith 2025-04-15 01:35:38 -04:00 committed by GitHub
parent 70e7ed841d
commit dbb036cf61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,6 +5,8 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.mamba2_metadata import (
_seq_idx_to_chunk_indices_offsets)
from vllm.model_executor.layers.mamba.ops.ssd_combined import ( from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined) mamba_chunk_scan_combined)
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -160,14 +162,14 @@ def generate_continous_batched_examples(example_lens_by_batch,
# get the metadata # get the metadata
cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0) cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0)
sed_idx = torch.zeros(cu_seqlens[-1], seq_idx = torch.zeros(cu_seqlens[-1],
dtype=torch.int32, dtype=torch.int32,
device=cu_seqlens.device) device=cu_seqlens.device)
for i, (srt, end) in enumerate(zip( for i, (srt, end) in enumerate(zip(
cu_seqlens, cu_seqlens,
cu_seqlens[1:], cu_seqlens[1:],
)): )):
sed_idx[srt:end] = i seq_idx[srt:end] = i
# for cont batch # for cont batch
if IND_E is None: if IND_E is None:
@ -177,7 +179,7 @@ def generate_continous_batched_examples(example_lens_by_batch,
IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)] IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)], yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
@pytest.mark.parametrize("itype", @pytest.mark.parametrize("itype",
@ -266,12 +268,15 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
states = None states = None
for Y_min, cu_seqlens, sed_idx, (A, dt, X, B, for Y_min, cu_seqlens, seq_idx, (A, dt, X, B,
C) in generate_continous_batched_examples( C) in generate_continous_batched_examples(
cases, num_examples, seqlen, cases, num_examples, seqlen,
last_taken, exhausted, n_heads, last_taken, exhausted, n_heads,
d_head, itype): d_head, itype):
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
seq_idx, chunk_size)
Y, new_states = mamba_chunk_scan_combined( Y, new_states = mamba_chunk_scan_combined(
X, X,
dt, dt,
@ -281,7 +286,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
chunk_size, chunk_size,
D=None, D=None,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
seq_idx=sed_idx, seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True, return_varlen_states=True,
initial_states=states, initial_states=states,
) )