[Bugfix] Fix tests/kernels/test_mamba_ssm_ssd.py (#16623)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent 70e7ed841d
commit dbb036cf61
@@ -5,6 +5,8 @@ import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat
 
+from vllm.model_executor.layers.mamba.mamba2_metadata import (
+    _seq_idx_to_chunk_indices_offsets)
 from vllm.model_executor.layers.mamba.ops.ssd_combined import (
     mamba_chunk_scan_combined)
 from vllm.platforms import current_platform
@@ -160,14 +162,14 @@ def generate_continous_batched_examples(example_lens_by_batch,
 
         # get the metadata
         cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0)
-        sed_idx = torch.zeros(cu_seqlens[-1],
+        seq_idx = torch.zeros(cu_seqlens[-1],
                               dtype=torch.int32,
                               device=cu_seqlens.device)
         for i, (srt, end) in enumerate(zip(
                 cu_seqlens,
                 cu_seqlens[1:],
         )):
-            sed_idx[srt:end] = i
+            seq_idx[srt:end] = i
 
         # for cont batch
         if IND_E is None:
@@ -177,7 +179,7 @@ def generate_continous_batched_examples(example_lens_by_batch,
             IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
 
         yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
-               cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
+               cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
 
 
 @pytest.mark.parametrize("itype",
@@ -266,12 +268,15 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
     exhausted: dict = {}  # map: eg -> boolean indicating example is exhausted
 
     states = None
-    for Y_min, cu_seqlens, sed_idx, (A, dt, X, B,
+    for Y_min, cu_seqlens, seq_idx, (A, dt, X, B,
                                      C) in generate_continous_batched_examples(
                                          cases, num_examples, seqlen,
                                          last_taken, exhausted, n_heads,
                                          d_head, itype):
 
+        chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
+            seq_idx, chunk_size)
+
         Y, new_states = mamba_chunk_scan_combined(
             X,
             dt,
@@ -281,7 +286,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
             chunk_size,
             D=None,
             cu_seqlens=cu_seqlens,
-            seq_idx=sed_idx,
+            seq_idx=seq_idx,
+            chunk_indices=chunk_indices,
+            chunk_offsets=chunk_offsets,
             return_varlen_states=True,
             initial_states=states,
         )
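
The two arguments added above exist because, under continuous batching, packed sequences generally do not start on chunk boundaries, so the chunked scan must know where each sequence begins inside its chunk. A minimal standalone sketch of that situation follows; it is illustrative only, does not reproduce _seq_idx_to_chunk_indices_offsets, and the names seqlens and boundaries are local to the example:

import torch

chunk_size = 8
seqlens = (5, 12, 7)  # three sequences packed back to back

# Build cu_seqlens and seq_idx exactly as the test above does.
cu_seqlens = torch.tensor((0, ) + seqlens).cumsum(dim=0)
seq_idx = torch.zeros(cu_seqlens[-1], dtype=torch.int32)
for i, (srt, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
    seq_idx[srt:end] = i

# A sequence that starts mid-chunk splits that chunk in two, and scan state
# must not be carried across the split. Collect (chunk index, offset within
# chunk) for every interior sequence start that misses a chunk boundary.
boundaries = [(int(s) // chunk_size, int(s) % chunk_size)
              for s in cu_seqlens[1:-1] if int(s) % chunk_size != 0]
print(seq_idx.tolist())  # [0]*5 + [1]*12 + [2]*7
print(boundaries)        # [(0, 5), (2, 1)]

Here sequences 1 and 2 begin at offsets 5 and 1 inside chunks 0 and 2; split-chunk cases like these are what the chunk_indices/chunk_offsets pair passed to mamba_chunk_scan_combined is meant to account for.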