[Bugfix] Fix tests/kernels/test_mamba_ssm_ssd.py (#16623)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
70e7ed841d
commit
dbb036cf61
@ -5,6 +5,8 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
_seq_idx_to_chunk_indices_offsets)
|
||||
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
||||
mamba_chunk_scan_combined)
|
||||
from vllm.platforms import current_platform
|
||||
@ -160,14 +162,14 @@ def generate_continous_batched_examples(example_lens_by_batch,
|
||||
|
||||
# get the metadata
|
||||
cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0)
|
||||
sed_idx = torch.zeros(cu_seqlens[-1],
|
||||
seq_idx = torch.zeros(cu_seqlens[-1],
|
||||
dtype=torch.int32,
|
||||
device=cu_seqlens.device)
|
||||
for i, (srt, end) in enumerate(zip(
|
||||
cu_seqlens,
|
||||
cu_seqlens[1:],
|
||||
)):
|
||||
sed_idx[srt:end] = i
|
||||
seq_idx[srt:end] = i
|
||||
|
||||
# for cont batch
|
||||
if IND_E is None:
|
||||
@ -177,7 +179,7 @@ def generate_continous_batched_examples(example_lens_by_batch,
|
||||
IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
|
||||
|
||||
yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
|
||||
cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
|
||||
cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype",
|
||||
@ -266,12 +268,15 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
||||
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
|
||||
|
||||
states = None
|
||||
for Y_min, cu_seqlens, sed_idx, (A, dt, X, B,
|
||||
for Y_min, cu_seqlens, seq_idx, (A, dt, X, B,
|
||||
C) in generate_continous_batched_examples(
|
||||
cases, num_examples, seqlen,
|
||||
last_taken, exhausted, n_heads,
|
||||
d_head, itype):
|
||||
|
||||
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
|
||||
seq_idx, chunk_size)
|
||||
|
||||
Y, new_states = mamba_chunk_scan_combined(
|
||||
X,
|
||||
dt,
|
||||
@ -281,7 +286,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
||||
chunk_size,
|
||||
D=None,
|
||||
cu_seqlens=cu_seqlens,
|
||||
seq_idx=sed_idx,
|
||||
seq_idx=seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
return_varlen_states=True,
|
||||
initial_states=states,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user