From fb60ae9b91a4b3e1aed4a6e826895fe3c5a13c10 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Thu, 17 Oct 2024 00:12:43 +0800 Subject: [PATCH] [Kernel][Model] Improve continuous batching for Jamba and Mamba (#9189) --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 37 ++-- csrc/mamba/causal_conv1d/causal_conv1d.h | 1 + csrc/mamba/mamba_ssm/selective_scan.h | 1 + csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 24 ++- csrc/ops.h | 28 +-- csrc/torch_bindings.cpp | 9 +- tests/kernels/test_causal_conv1d.py | 189 +++++++++--------- tests/kernels/test_mamba_ssm.py | 124 ++++++++---- .../decoder_only/language/test_jamba.py | 25 +++ vllm/_custom_ops.py | 65 +++--- .../layers/mamba/ops/causal_conv1d.py | 53 +++-- .../layers/mamba/ops/mamba_ssm.py | 70 ++++--- vllm/model_executor/models/jamba.py | 71 +++---- vllm/model_executor/models/mamba.py | 53 ++--- vllm/model_executor/models/mamba_cache.py | 186 ++++++----------- 15 files changed, 504 insertions(+), 432 deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 30831efd..3a464c5f 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -55,6 +55,7 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, const at::Tensor out, const c10::optional& bias, bool silu_activation, + int64_t pad_slot_id, const c10::optional& query_start_loc = std::nullopt, const c10::optional& cache_indices = std::nullopt, const c10::optional& has_initial_state = std::nullopt) { @@ -66,6 +67,7 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, params.dim = dim; params.seqlen = seqlen; params.width = width; + params.pad_slot_id = pad_slot_id; params.silu_activation = silu_activation; @@ -90,14 +92,16 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, } -at::Tensor -causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, +void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, const c10::optional &bias_, const c10::optional &conv_states, const c10::optional &query_start_loc, const c10::optional &cache_indices, const c10::optional &has_initial_state, - bool silu_activation) { + bool silu_activation, + // used to identify padding entries if cache_indices provided + // in case of padding, the kernel will return early + int64_t pad_slot_id) { auto input_type = x.scalar_type(); auto weight_type = weight.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -153,12 +157,13 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, CHECK_SHAPE(cache_indices_, batch_size); } - at::Tensor out = torch::empty_like(x); + at::Tensor out = x; ConvParamsBase params; set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, bias_, silu_activation, + pad_slot_id, query_start_loc, cache_indices, has_initial_state @@ -183,18 +188,19 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { causal_conv1d_fwd_cuda(params, stream); }); - return out; } -at::Tensor -causal_conv1d_update(const at::Tensor &x, +void causal_conv1d_update(const at::Tensor &x, const at::Tensor &conv_state, const at::Tensor &weight, const c10::optional &bias_, bool silu_activation, const c10::optional &cache_seqlens_, - const c10::optional &conv_state_indices_) { + const c10::optional &conv_state_indices_, + // used to identify padding entries if cache_indices provided + // in case of padding, the kernel will return early + int64_t pad_slot_id) { auto input_type = x.scalar_type(); auto weight_type = weight.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -227,12 +233,13 @@ causal_conv1d_update(const at::Tensor &x, CHECK_SHAPE(bias, dim); } - at::Tensor out = torch::empty_like(x); + at::Tensor out = x; ConvParamsBase params; set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, bias_, - silu_activation); + silu_activation, + pad_slot_id); params.conv_state_ptr = conv_state.data_ptr(); params.conv_state_len = conv_state_len; // All stride are in elements, not bytes. @@ -274,7 +281,6 @@ causal_conv1d_update(const at::Tensor &x, DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { causal_conv1d_update_cuda(params, stream); }); - return out; } template @@ -340,7 +346,10 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr : reinterpret_cast(params.cache_indices_ptr); int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; - + // cache_index == params.pad_slot_id is defined as padding, so we exit early + if (cache_index == params.pad_slot_id){ + return; + } input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr : reinterpret_cast(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride; @@ -528,6 +537,10 @@ void causal_conv1d_update_kernel(ConvParamsBase params) { const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr ? batch_id : params.conv_state_indices_ptr[batch_id]; + // conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early + if (conv_state_batch_coord == params.pad_slot_id){ + return; + } input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + conv_state_batch_coord * params.conv_state_batch_stride + channel_id * params.conv_state_c_stride; diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h index 49e37ee4..e26684a2 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.h +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -13,6 +13,7 @@ struct ConvParamsBase { using index_t = uint32_t; int batch, dim, seqlen, width; + int64_t pad_slot_id; bool silu_activation; index_t x_batch_stride; diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h index 580d0b2e..563d2fe4 100644 --- a/csrc/mamba/mamba_ssm/selective_scan.h +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -21,6 +21,7 @@ struct SSMParamsBase { int dim_ngroups_ratio; bool is_variable_B; bool is_variable_C; + int64_t pad_slot_id; bool delta_softplus; diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 6b225b41..71624696 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -115,6 +115,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr : reinterpret_cast(params.cache_indices_ptr); const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; + // cache_index == params.pad_slot_id is defined as padding, so we exit early + if (cache_index == params.pad_slot_id){ + return; + } input_t *u = reinterpret_cast(params.u_ptr) + sequence_start_index * params.u_batch_stride + dim_id * kNRows * params.u_d_stride; input_t *delta = reinterpret_cast(params.delta_ptr) + sequence_start_index * params.delta_batch_stride @@ -387,7 +391,6 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, const size_t seqlen, const size_t dstate, const size_t n_groups, - const size_t n_chunks, const bool is_variable_B, const bool is_variable_C, // device pointers @@ -407,7 +410,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, const c10::optional& query_start_loc, const c10::optional& cache_indices, const c10::optional& has_initial_state, - bool varlen) { + bool varlen, + int64_t pad_slot_id) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -417,8 +421,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.seqlen = seqlen; params.dstate = dstate; params.n_groups = n_groups; - params.n_chunks = n_chunks; params.dim_ngroups_ratio = dim / n_groups; + params.pad_slot_id = pad_slot_id; params.delta_softplus = delta_softplus; @@ -507,7 +511,10 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, const c10::optional &query_start_loc, const c10::optional &cache_indices, const c10::optional &has_initial_state, - const torch::Tensor &ssm_states) { + const torch::Tensor &ssm_states, + // used to identify padding entries if cache_indices provided + // in case of padding, the kernel will return early + int64_t pad_slot_id) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -618,18 +625,14 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, out_z = z; - const int n_chunks = (seqlen + 2048 - 1) / 2048; - // const int n_chunks = (seqlen + 1024 - 1) / 1024; - // at::Tensor out = torch::empty_like(u); // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout at::Tensor out = delta; TORCH_CHECK(ssm_states.scalar_type() == input_type); TORCH_CHECK(ssm_states.is_cuda()); TORCH_CHECK(ssm_states.stride(-1) == 1); - CHECK_SHAPE(ssm_states, batch_size, dim, dstate); SSMParamsBase params; - set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C, u, delta, A, B, C, out, z, out_z, D_, delta_bias_, @@ -639,7 +642,8 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, query_start_loc, cache_indices, has_initial_state, - varlen + varlen, + pad_slot_id ); diff --git a/csrc/ops.h b/csrc/ops.h index fce545f9..c10c34e0 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -157,21 +157,23 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, const c10::optional& query_start_loc, const c10::optional& cache_indices, const c10::optional& has_initial_state, - const torch::Tensor& ssm_states); + const torch::Tensor& ssm_states, int64_t pad_slot_id); -at::Tensor causal_conv1d_update( - const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight, - const c10::optional& bias_, bool silu_activation, - const c10::optional& cache_seqlens_, - const c10::optional& conv_state_indices_); +void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state, + const at::Tensor& weight, + const c10::optional& bias_, + bool silu_activation, + const c10::optional& cache_seqlens_, + const c10::optional& conv_state_indices_, + int64_t pad_slot_id); -at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, - const c10::optional& bias_, - const c10::optional& conv_states, - const c10::optional& query_start_loc, - const c10::optional& cache_indices, - const c10::optional& has_initial_state, - bool silu_activation); +void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, + const c10::optional& bias_, + const c10::optional& conv_states, + const c10::optional& query_start_loc, + const c10::optional& cache_indices, + const c10::optional& has_initial_state, + bool silu_activation, int64_t pad_slot_id); #ifndef USE_ROCM using fptr_t = int64_t; diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index a0100b4a..d69c4e5a 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -278,7 +278,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? query_start_loc," "Tensor? cache_indices," "Tensor? has_initial_state," - "Tensor! ssm_states) -> ()"); + "Tensor! ssm_states," + "int pad_slot_id) -> ()"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); ops.def( @@ -288,7 +289,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? bias_," "bool silu_activation," "Tensor? cache_seqlens_," - "Tensor? conv_state_indices) -> Tensor"); + "Tensor? conv_state_indices," + "int pad_slot_id) -> ()"); ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); ops.def( @@ -298,7 +300,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? query_start_loc," "Tensor? cache_indices," "Tensor? has_initial_state," - "bool silu_activation) -> Tensor"); + "bool silu_activation," + "int pad_slot_id) -> ()"); ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); #endif diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 069020a5..277d7e49 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -6,6 +6,7 @@ import torch.nn.functional as F from tests.kernels.utils import opcheck from vllm import _custom_ops as ops # noqa: F401 +from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.utils import seed_everything @@ -114,16 +115,15 @@ def causal_conv1d_update_ref(x, @pytest.mark.parametrize("itype", [torch.bfloat16, torch.float]) @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) -def causal_conv1d_opcheck_fn( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - cu_seq_len: Optional[torch.Tensor] = None, - cache_indices: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - conv_states: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", -): +def causal_conv1d_opcheck_fn(x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + cu_seq_len: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID): """ x: (batch, dim, seqlen) weight: (dim, width) @@ -141,16 +141,9 @@ def causal_conv1d_opcheck_fn( x = x.contiguous() bias = bias.contiguous() if bias is not None else None - opcheck(torch.ops._C.causal_conv1d_fwd, ( - x, - weight, - bias, - conv_states, - cu_seq_len, - cache_indices, - has_initial_state, - activation in ["silu", "swish"], - )) + opcheck(torch.ops._C.causal_conv1d_fwd, + (x, weight, bias, conv_states, cu_seq_len, cache_indices, + has_initial_state, activation in ["silu", "swish"], pad_slot_id)) @pytest.mark.parametrize("itype", [torch.bfloat16, torch.float]) @@ -233,17 +226,11 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, seed_everything(0) batch = 2 x = torch.randn(batch, dim, seqlen, device=device, dtype=itype) + x_ref = x.clone() conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype) - weight = torch.randn(dim, - width, - device=device, - dtype=itype, - requires_grad=True) - if has_bias: - bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True) - else: - bias = None + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None conv_state_ref = conv_state.detach().clone() activation = None if not silu_activation else "silu" out = causal_conv1d_update(x, @@ -251,7 +238,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, weight, bias, activation=activation) - out_ref = causal_conv1d_update_ref(x, + out_ref = causal_conv1d_update_ref(x_ref, conv_state_ref, weight, bias, @@ -260,15 +247,9 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - opcheck(torch.ops._C.causal_conv1d_update, ( - x, - conv_state, - weight, - bias, - activation in ["silu", "swish"], - None, - None, - )) + opcheck(torch.ops._C.causal_conv1d_update, + (x, conv_state, weight, bias, activation + in ["silu", "swish"], None, None, PAD_SLOT_ID)) @pytest.mark.parametrize("itype", @@ -278,37 +259,48 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, @pytest.mark.parametrize("seqlen", [1, 4, 5]) @pytest.mark.parametrize("width", [2, 3, 4]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, +# tests correctness in case subset of the sequences are padded +@pytest.mark.parametrize("with_padding", [True, False]) +def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, + seqlen, has_bias, silu_activation, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 - # set )seed + # set seed seed_everything(0) - batch = 64 - x = torch.randn(batch, dim, 1, device=device, dtype=itype) + batch_size = 3 + padding = 5 if with_padding else 0 + padded_batch_size = batch_size + padding + total_entries = 10 * batch_size - total_entries = 10 * batch + x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype) + x_ref = x.clone() + + conv_state_indices = torch.randperm(total_entries)[:batch_size].to( + dtype=torch.int32, device=device) + unused_states_bool = torch.ones(total_entries, + dtype=torch.bool, + device=device) + unused_states_bool[conv_state_indices] = False + padded_state_indices = torch.concat([ + conv_state_indices, + torch.as_tensor( + [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) + ], + dim=0) conv_state = torch.randn(total_entries, dim, width - 1, device=device, dtype=itype) - conv_state_indices = torch.randperm(total_entries)[:batch].to( - dtype=torch.int32, device=device) + conv_state_for_padding_test = conv_state.clone() - weight = torch.randn(dim, - width, - device=device, - dtype=itype, - requires_grad=True) - if has_bias: - bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True) - else: - bias = None + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None conv_state_ref = conv_state[conv_state_indices, :].detach().clone() activation = None if not silu_activation else "silu" out = causal_conv1d_update(x, @@ -316,45 +308,50 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, weight, bias, activation=activation, - conv_state_indices=conv_state_indices) - out_ref = causal_conv1d_update_ref(x, + conv_state_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID) + out_ref = causal_conv1d_update_ref(x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation) assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) + assert torch.equal(conv_state[unused_states_bool], + conv_state_for_padding_test[unused_states_bool]) - opcheck(torch.ops._C.causal_conv1d_update, ( - x, - conv_state, - weight, - bias, - activation in ["silu", "swish"], - None, - conv_state_indices, - )) + opcheck(torch.ops._C.causal_conv1d_update, + (x, conv_state, weight, bias, activation + in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID)) @pytest.mark.parametrize("itype", [torch.bfloat16]) @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize('seqlen', - [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) +@pytest.mark.parametrize( + 'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096]) @pytest.mark.parametrize('dim', [64, 4096]) -def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, - itype): +# tests correctness in case subset of the sequences are padded +@pytest.mark.parametrize('with_padding', [True, False]) +def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, + silu_activation, itype): device = "cuda" + torch.cuda.empty_cache() rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 # set seed seed_everything(0) - batch = 1 seqlens = [] - nsplits = 3 + batch_size = 4 + if seqlen < 10: + batch_size = 1 + padding = 3 if with_padding else 0 + padded_batch_size = batch_size + padding + nsplits = padded_batch_size - 1 + eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values seqlens.append( torch.diff( @@ -364,10 +361,11 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, assert sum(seqlens[-1]) == seqlen assert all(s > 0 for s in seqlens[-1]) + total_entries = batch_size * 10 cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0) - x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, + x = torch.randn(1, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :] weight = torch.randn(dim, width, device=device, dtype=itype) bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None @@ -375,7 +373,7 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, weight_ref = weight.clone() bias_ref = bias.clone() if bias is not None else None activation = None if not silu_activation else "silu" - final_states = torch.randn(nsplits + 1, + final_states = torch.randn(total_entries, dim, width - 1, device=x.device, @@ -385,18 +383,27 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, 2, (cumsum.shape[0] - 1, ), dtype=torch.bool, device=x.device) - cache_indices = torch.randperm(cumsum.shape[0] - 1, + state_indices = torch.randperm(total_entries, dtype=torch.int32, - device=x.device) + device=x.device)[:batch_size] + padded_state_indices = torch.concat([ + state_indices, + torch.as_tensor( + [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=-1) + out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(), - cache_indices, has_initial_states, final_states, - activation) + padded_state_indices, has_initial_states, + final_states, activation, PAD_SLOT_ID) out_ref = [] out_ref_b = [] splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)] for i in range(len(seqlens[0])): x_s = [v[i].unsqueeze(0) for v in splits][0] + if padded_state_indices[i] == PAD_SLOT_ID: + continue out_ref_b.append( causal_conv1d_ref( x_s, @@ -404,21 +411,17 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, bias_ref, activation=activation, return_final_states=True, - final_states_out=final_states_ref[cache_indices[i]].unsqueeze( - 0), - initial_states=final_states_ref[cache_indices[i]].unsqueeze(0) - if has_initial_states[i] else None)) + final_states_out=final_states_ref[ + padded_state_indices[i]].unsqueeze(0), + initial_states=final_states_ref[padded_state_indices[i]]. + unsqueeze(0) if has_initial_states[i] else None)) out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2)) - out_ref = torch.cat(out_ref, dim=0) + out_ref_tensor = torch.cat(out_ref, dim=0) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print("Output state max diff" - f":{(final_states - final_states_ref).abs().max()}") - print("Output state mean diff" - f":{(final_states - final_states_ref).abs().mean()}") - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + unpadded_out = out[:, :out_ref_tensor.shape[-1]] + assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol) assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol) + causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(), - cache_indices, has_initial_states, final_states, - activation) + padded_state_indices, has_initial_states, + final_states, activation) diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index 8fa55e75..e92d4013 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -5,6 +5,7 @@ from einops import rearrange, repeat from tests.kernels.utils import opcheck from vllm import _custom_ops as ops # noqa: F401 +from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) from vllm.utils import seed_everything @@ -174,7 +175,8 @@ def selective_scan_opcheck_fn(u, cu_seq_len=None, cache_indices=None, has_initial_state=None, - ssm_states=None): + ssm_states=None, + pad_slot_id=PAD_SLOT_ID): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). """ @@ -203,7 +205,7 @@ def selective_scan_opcheck_fn(u, # a bogus error. opcheck(torch.ops._C.selective_scan_fwd, (u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seq_len, - cache_indices, has_initial_state, ssm_states), + cache_indices, has_initial_state, ssm_states, pad_slot_id), test_utils=["test_schema", "test_faketensor"]) @@ -404,9 +406,12 @@ def test_selective_state_update(dim, dstate, has_z, itype): @pytest.mark.parametrize("varBC_groups", [1, 2]) @pytest.mark.parametrize("is_variable_C", [True]) @pytest.mark.parametrize("is_variable_B", [True]) -def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, - has_D, has_z, has_delta_bias, delta_softplus, - return_last_state, seqlen, itype, wtype): +# tests correctness in case subset of the sequences are padded +@pytest.mark.parametrize("with_padding", [False, True]) +def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, + varBC_groups, has_D, has_z, has_delta_bias, + delta_softplus, return_last_state, seqlen, + itype, wtype): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable device = 'cuda' @@ -420,18 +425,27 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, # set seed torch.random.manual_seed(0) seqlens = [] - nsplits = 3 + batch_size = 4 if seqlen < 10: - nsplits = 0 + batch_size = 1 + padding = 3 if with_padding else 0 + padded_batch_size = batch_size + padding + + if with_padding and seqlen < padded_batch_size: + pytest.skip() + + nsplits = padded_batch_size - 1 eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values seqlens.append( torch.diff( torch.cat( [torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])).tolist()) + assert sum(seqlens[-1]) == seqlen assert all(s > 0 for s in seqlens[-1]) + total_entries = batch_size * 10 cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0).cuda() @@ -462,22 +476,33 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, delta_ref = delta.clone() out = None out_ref = None - prev_state_shape = (cumsum.shape[0] - 1, u.shape[0], int(A.shape[1])) + + prev_state_shape = (total_entries, u.shape[0], int(A.shape[1])) prev_state = torch.randn(prev_state_shape, device=u.device, dtype=itype, requires_grad=False) prev_state_ref = prev_state.clone() - cache_indices = torch.randperm(cumsum.shape[0] - 1, + state_indices = torch.randperm(total_entries, dtype=torch.int32, - device=u.device) + device=u.device)[:batch_size] + unused_states_bool = torch.ones(total_entries, + dtype=torch.bool, + device=device) + unused_states_bool[state_indices] = False + padded_state_indices = torch.concat([ + state_indices, + torch.as_tensor( + [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=-1) has_initial_state = torch.randint(0, 2, (cumsum.shape[0] - 1, ), dtype=torch.bool, device=u.device) out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias, - delta_softplus, cumsum, cache_indices, + delta_softplus, cumsum, padded_state_indices, has_initial_state) outs_ref = [] splits = [ @@ -486,6 +511,8 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, ] for i in range(len(seqlens[0])): u_s, delta_s, B_s, C_s, z_s = [v[i].unsqueeze(0) for v in splits] + if padded_state_indices[i] == PAD_SLOT_ID: + continue out_ref_s, _ = selective_scan_ref( u_s, delta_s, @@ -497,21 +524,22 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, delta_bias=delta_bias, delta_softplus=delta_softplus, return_last_state=return_last_state, - prev_state=prev_state_ref[cache_indices[i]].unsqueeze(0) + prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0) if has_initial_state[i] else None, - final_state_out=prev_state_ref[cache_indices[i]].unsqueeze(0)) + final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze( + 0)) outs_ref.append(out_ref_s) - out_ref = torch.cat(outs_ref, dim=-1) if len(outs_ref) > 1 else outs_ref[0] + out_ref = torch.cat(outs_ref, dim=-1)[0] - print("Output diff max", (out - out_ref[0]).max()) - print("Output diff mean", (out - out_ref[0]).mean()) + unpadded_out = out[:, :out_ref[0].shape[-1]] + print("Output diff max", (unpadded_out - out_ref).max()) + print("Output diff mean", (unpadded_out - out_ref).mean()) print("Output state diff max", (prev_state - prev_state_ref).max()) print("Output state diff mean", (prev_state - prev_state_ref).mean()) assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol) - assert torch.allclose(out, out_ref[0], rtol=rtol, atol=atol) - + assert torch.allclose(unpadded_out, out_ref, rtol=rtol, atol=atol) selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias, - delta_softplus, cumsum, cache_indices, + delta_softplus, cumsum, padded_state_indices, has_initial_state, prev_state) @@ -520,7 +548,10 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, @pytest.mark.parametrize("has_z", [True]) @pytest.mark.parametrize("dstate", [16, 32, 64]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): +# tests correctness in case subset of the sequences are padded +@pytest.mark.parametrize("with_padding", [True, False]) +def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, + has_z, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) if itype == torch.bfloat16: @@ -530,21 +561,32 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): # set seed torch.random.manual_seed(0) batch_size = 3 - + padding = 5 if with_padding else 0 + padded_batch_size = batch_size + padding total_entries = 10 * batch_size state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) state_indices = torch.randperm(total_entries)[:batch_size].to( dtype=torch.int32, device=device) - - x = torch.randn(batch_size, dim, device=device, dtype=itype) - dt = torch.randn(batch_size, dim, device=device, dtype=itype) + unused_states_bool = torch.ones(total_entries, + dtype=torch.bool, + device=device) + unused_states_bool[state_indices] = False + padded_state_indices = torch.concat([ + state_indices, + torch.as_tensor( + [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) + ], + dim=0) + x = torch.randn(padded_batch_size, dim, device=device, dtype=itype) + dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype) dt_bias = torch.rand(dim, device=device) - 4.0 A = -torch.rand(dim, dstate, device=device) - 1.0 - B = torch.randn(batch_size, dstate, device=device) - C = torch.randn(batch_size, dstate, device=device) + B = torch.randn(padded_batch_size, dstate, device=device) + C = torch.randn(padded_batch_size, dstate, device=device) D = torch.randn(dim, device=device) z = torch.randn_like(x) if has_z else None - state_ref = state[state_indices, :].detach().clone() + state_ref = state[state_indices, :].clone() + state_before = state.clone() out = selective_state_update(state, x, dt, @@ -555,15 +597,16 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): z=z, dt_bias=dt_bias, dt_softplus=True, - state_batch_indices=state_indices) + state_batch_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID) out_ref = selective_state_update_ref(state_ref, - x, - dt, + x[:batch_size], + dt[:batch_size], A, - B, - C, + B[:batch_size], + C[:batch_size], D=D, - z=z, + z=z[:batch_size], dt_bias=dt_bias, dt_softplus=True) @@ -572,11 +615,21 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): print("Output state diff max", (state[state_indices, :] - state_ref).max()) print("Output state diff mean", (state[state_indices, :] - state_ref).mean()) + # test padded entries stay the same + if with_padding: + assert torch.equal(state_before[unused_states_bool], + state[unused_states_bool]) + assert torch.equal(x[batch_size + 1:], x[batch_size + 1:]) + assert torch.equal(dt[batch_size + 1:], dt[batch_size + 1:]) + assert torch.equal(B[batch_size + 1:], B[batch_size + 1:]) + assert torch.equal(C[batch_size + 1:], C[batch_size + 1:]) + + # test "real" entries assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol) - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) @pytest.mark.parametrize("itype", @@ -645,7 +698,8 @@ def test_selective_state_update_with_heads_with_batch_indices( z=z, dt_bias=dt_bias, dt_softplus=True, - state_batch_indices=state_indices) + state_batch_indices=state_indices, + pad_slot_id=PAD_SLOT_ID) out_ref = selective_state_update_ref(state_ref, x, dt, diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py index 408d12cd..384ec77e 100644 --- a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_jamba.py @@ -1,5 +1,6 @@ import pytest +from tests.utils import multi_gpu_test from vllm.sampling_params import SamplingParams from vllm.worker.model_runner import _get_graph_batch_size @@ -270,6 +271,30 @@ def test_state_cleanup( "could be related to finished_requests_ids") +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [64]) +def test_jamba_distributed_produces_identical_generation( + vllm_runner, model: str, dtype: str, max_tokens: int, + example_prompts) -> None: + + with vllm_runner(model, dtype=dtype, tensor_parallel_size=2) as vllm_model: + vllm_outputs_tp_2 = vllm_model.generate_greedy(example_prompts, + max_tokens) + + with vllm_runner(model, dtype=dtype, tensor_parallel_size=1) as vllm_model: + vllm_outputs_tp_1 = vllm_model.generate_greedy(example_prompts, + max_tokens) + + check_outputs_equal( + outputs_0_lst=vllm_outputs_tp_1, + outputs_1_lst=vllm_outputs_tp_2, + name_0="vllm_tp_1", + name_1="vllm_tp_2", + ) + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) def test_model_print( diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3a236922..ec035f13 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -464,16 +464,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): cu_seq_len: Optional[torch.Tensor], cache_indices: Optional[torch.Tensor], has_initial_state: Optional[torch.Tensor], - silu_activation: bool) -> torch.Tensor: - return torch.empty_like(x) + silu_activation: bool, pad_slot_id: int): + return None @register_fake("_C::causal_conv1d_update") - def causal_conv1d_update_fake( - x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, - bias_: Optional[torch.Tensor], silu_activation: bool, - cache_seqlens: Optional[torch.Tensor], - conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: - return torch.empty_like(x) + def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor, + weight: torch.Tensor, + bias_: Optional[torch.Tensor], + silu_activation: bool, + cache_seqlens: Optional[torch.Tensor], + conv_state_indices: Optional[torch.Tensor], + pad_slot_id: int) -> None: + return None @register_fake("_C::selective_scan_fwd") def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor, @@ -485,7 +487,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): cu_seq_len: Optional[torch.Tensor], cache_indices: Optional[torch.Tensor], has_initial_state: Optional[torch.Tensor], - ssm_states: Optional[torch.Tensor]) -> None: + ssm_states: Optional[torch.Tensor], + pad_slot_id: int) -> None: return None @@ -800,33 +803,37 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, query_start_loc: Optional[torch.Tensor], cache_indices: Optional[torch.Tensor], has_initial_state: Optional[torch.Tensor], - silu_activation: bool) -> torch.Tensor: - return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states, - query_start_loc, cache_indices, - has_initial_state, silu_activation) + silu_activation: bool, pad_slot_id: int): + torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states, + query_start_loc, cache_indices, + has_initial_state, silu_activation, + pad_slot_id) -def causal_conv1d_update( - x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, - bias_: Optional[torch.Tensor], silu_activation: bool, - cache_seqlens: Optional[torch.Tensor], - conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: - return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, - silu_activation, cache_seqlens, - conv_state_indices) +def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, + weight: torch.Tensor, bias_: Optional[torch.Tensor], + silu_activation: bool, + cache_seqlens: Optional[torch.Tensor], + conv_state_indices: Optional[torch.Tensor], + pad_slot_id: int): + torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, + silu_activation, cache_seqlens, + conv_state_indices, pad_slot_id) -def selective_scan_fwd( - u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, - C: torch.Tensor, D_: Optional[torch.Tensor], - z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor], - delta_softplus: bool, query_start_loc: Optional[torch.Tensor], - cache_indices: Optional[torch.Tensor], - has_initial_state: Optional[torch.Tensor], ssm_states: torch.Tensor): +def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, + B: torch.Tensor, C: torch.Tensor, + D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], + delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, + query_start_loc: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], + ssm_states: torch.Tensor, pad_slot_id: int): torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_, delta_softplus, query_start_loc, cache_indices, has_initial_state, - ssm_states) + ssm_states, pad_slot_id) # moe diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index ed7241af..be5639df 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -6,18 +6,18 @@ from typing import Optional import torch from vllm import _custom_ops as ops +from vllm.attention.backends.utils import PAD_SLOT_ID -def causal_conv1d_fn( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - query_start_loc: Optional[torch.Tensor] = None, - cache_indices: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - conv_states: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", -): +def causal_conv1d_fn(x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID): """ x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen sequences are concatenated from left to right for varlen @@ -37,6 +37,13 @@ def causal_conv1d_fn( conv_states: (...,dim,width - 1) itype updated inplace if provided activation: either None or "silu" or "swish" + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim, seqlen) """ @@ -46,10 +53,10 @@ def causal_conv1d_fn( x = x.contiguous() bias = bias.contiguous() if bias is not None else None - out = ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc, - cache_indices, has_initial_state, activation - in ["silu", "swish"]) - return out + ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc, + cache_indices, has_initial_state, activation + in ["silu", "swish"], pad_slot_id) + return x def causal_conv1d_update(x: torch.Tensor, @@ -58,7 +65,8 @@ def causal_conv1d_update(x: torch.Tensor, bias: Optional[torch.Tensor] = None, activation: Optional[str] = None, cache_seqlens: Optional[torch.Tensor] = None, - conv_state_indices: Optional[torch.Tensor] = None): + conv_state_indices: Optional[torch.Tensor] = None, + pad_slot_id: int = PAD_SLOT_ID): """ x: (batch, dim) or (batch, dim, seqlen) conv_state: (batch, dim, state_len), where state_len >= width - 1 @@ -73,7 +81,12 @@ def causal_conv1d_update(x: torch.Tensor, If not None, the conv_state is a larger tensor along the batch dim, and we are selecting the batch coords specified by conv_state_indices. Useful for a continuous batching scenario. - + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 out: (batch, dim) or (batch, dim, seqlen) """ if activation not in [None, "silu", "swish"]: @@ -82,8 +95,8 @@ def causal_conv1d_update(x: torch.Tensor, unsqueeze = x.dim() == 2 if unsqueeze: x = x.unsqueeze(-1) - out = ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val, - cache_seqlens, conv_state_indices) + ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val, + cache_seqlens, conv_state_indices, pad_slot_id) if unsqueeze: - out = out.squeeze(-1) - return out + x = x.squeeze(-1) + return x diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 08b016c2..1484b798 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -1,14 +1,13 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. # Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py -from typing import Tuple - import torch import triton import triton.language as tl from packaging import version from vllm import _custom_ops as ops +from vllm.attention.backends.utils import PAD_SLOT_ID TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") @@ -50,6 +49,7 @@ def _selective_scan_update_kernel( z_ptr, out_ptr, state_batch_indices_ptr, + pad_slot_id, # Matrix dimensions batch, nheads, @@ -143,10 +143,11 @@ def _selective_scan_update_kernel( if HAS_Z: z_ptrs = z_ptr + offs_m * stride_z_dim out_ptrs = out_ptr + offs_m * stride_out_dim + mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) + if HAS_STATE_BATCH_INDICES: + mask &= (state_batch_idx != pad_slot_id) + state = tl.load(state_ptrs, mask=mask, other=0.0) - state = tl.load(state_ptrs, - mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), - other=0.0) x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if not TIE_HDIM: dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) @@ -177,9 +178,11 @@ def _selective_scan_update_kernel( dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt state = state * dA + dB * x[:, None] - tl.store(state_ptrs, - state, - mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) + + mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) + if HAS_STATE_BATCH_INDICES: + mask &= (state_batch_idx != pad_slot_id) + tl.store(state_ptrs, state, mask=mask) out = tl.sum(state * C[None, :], axis=1) if HAS_D: out += x * D @@ -198,7 +201,8 @@ def selective_state_update(state, z=None, dt_bias=None, dt_softplus=False, - state_batch_indices=None): + state_batch_indices=None, + pad_slot_id=PAD_SLOT_ID): """ Argument: state: (batch, dim, dstate) or (batch, nheads, dim, dstate) @@ -210,6 +214,12 @@ def selective_state_update(state, D: (dim,) or (nheads, dim) z: (batch, dim) or (batch, nheads, dim) dt_bias: (dim,) or (nheads, dim) + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 Return: out: (batch, dim) or (batch, nheads, dim) """ @@ -276,6 +286,7 @@ def selective_state_update(state, z, out, state_batch_indices, + pad_slot_id, batch, nheads, dim, @@ -319,22 +330,25 @@ def selective_state_update(state, return out -def selective_scan_fn( - u, - ssm_states, - delta, - A, - B, - C, - D=None, - z=None, - delta_bias=None, - delta_softplus=False, - query_start_loc=None, - cache_indices=None, - has_initial_state=None) -> Tuple[torch.Tensor, torch.Tensor]: +def selective_scan_fn(u, + ssm_states, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + query_start_loc=None, + cache_indices=None, + has_initial_state=None, + pad_slot_id=PAD_SLOT_ID) -> torch.Tensor: """ u: (dim, total_length) for varlen or (batch, dim, seqlen) + applies changes in place. + ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate) + applies changes in place. delta: (dim, total_length) for varlen or (batch, dim, seqlen) A: (dim, dstate) B: (ngroups, dstate, total_length) for varlen or @@ -357,12 +371,14 @@ def selective_scan_fn( indicate if the ssm_state at the corresponding index should be used as initial state. Not providing argument assumes there's no initial state - + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padding entries + that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at indices 0 and 3 returns output: (dim, total_length) for varlen or (batch, dim, seqlen) supports inplace replacement - last_state has shape (batch, dim, dstate). - supports inplace replacement if ssm_state was provided """ if u.stride(-1) != 1: u = u.contiguous() @@ -387,7 +403,7 @@ def selective_scan_fn( ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, query_start_loc, cache_indices, has_initial_state, - ssm_states) + ssm_states, pad_slot_id) if z is None: return delta # output written inplace to delta diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index ac251b88..fddd39fb 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,6 +1,5 @@ # coding=utf-8 """Inference-only Jamba model.""" -from dataclasses import dataclass from typing import Iterable, List, Optional, Tuple import torch @@ -29,7 +28,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( composed_weight_loader, default_weight_loader, sharded_weight_loader) -from vllm.model_executor.models.mamba_cache import MambaCacheManager +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors @@ -41,13 +41,6 @@ from .interfaces import HasInnerState, SupportsLoRA KVCache = Tuple[torch.Tensor, torch.Tensor] -@dataclass -class MambaCacheParams: - is_prompt: bool = False - conv_state: torch.Tensor = torch.Tensor() - ssm_state: torch.Tensor = torch.Tensor() - - # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer class JambaMambaMixer(nn.Module): """ @@ -60,10 +53,9 @@ class JambaMambaMixer(nn.Module): **selective** state spaces) """ - def __init__(self, config: JambaConfig, layer_idx): + def __init__(self, config: JambaConfig): super().__init__() self.config = config - self.layer_idx = layer_idx self.hidden_size = config.hidden_size self.ssm_state_size = config.mamba_d_state self.conv_kernel_size = config.mamba_d_conv @@ -129,8 +121,8 @@ class JambaMambaMixer(nn.Module): eps=config.rms_norm_eps) def forward(self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, conv_state: torch.Tensor, - ssm_state: torch.Tensor): + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams): # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) @@ -153,17 +145,18 @@ class JambaMambaMixer(nn.Module): conv_weights, self.conv1d.bias, activation=self.activation, - conv_states=conv_state, + conv_states=mamba_cache_params.conv_state, has_initial_state=attn_metadata.context_lens_tensor > 0, + cache_indices=mamba_cache_params.state_indices_tensor, query_start_loc=attn_metadata.query_start_loc) else: hidden_states = causal_conv1d_update( hidden_states.transpose(0, 1), - conv_state, + mamba_cache_params.conv_state, conv_weights, self.conv1d.bias, self.activation, - ) + conv_state_indices=mamba_cache_params.state_indices_tensor) hidden_states = hidden_states.transpose(0, 1) # 3. State Space Model sequence transformation @@ -188,7 +181,7 @@ class JambaMambaMixer(nn.Module): and attn_metadata.context_lens_tensor is not None: scan_outputs = selective_scan_fn( hidden_states, - ssm_state, + mamba_cache_params.ssm_state, discrete_time_step, self.A, B.transpose(-2, -1), @@ -197,11 +190,12 @@ class JambaMambaMixer(nn.Module): gate, time_proj_bias, delta_softplus=True, + cache_indices=mamba_cache_params.state_indices_tensor, has_initial_state=attn_metadata.context_lens_tensor > 0, query_start_loc=attn_metadata.query_start_loc) else: scan_outputs = selective_state_update( - ssm_state, + mamba_cache_params.ssm_state, hidden_states.transpose(0, 1), discrete_time_step.transpose(0, 1), self.A, @@ -211,7 +205,7 @@ class JambaMambaMixer(nn.Module): gate.transpose(0, 1), time_proj_bias, dt_softplus=True, - ) + state_batch_indices=mamba_cache_params.state_indices_tensor) scan_outputs = scan_outputs.transpose(0, 1) # 4. Final linear projection @@ -292,7 +286,7 @@ class JambaMambaDecoderLayer(nn.Module): super().__init__() self.layer_idx = layer_idx self.config = config - self.mamba = JambaMambaMixer(config, layer_idx) + self.mamba = JambaMambaMixer(config) num_experts = config.layers_num_experts[layer_idx] ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP @@ -307,8 +301,7 @@ class JambaMambaDecoderLayer(nn.Module): hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], - conv_state: torch.Tensor, - ssm_state: torch.Tensor, + mamba_cache_params: MambaCacheParams, **kwargs, ): if residual is None: @@ -318,8 +311,8 @@ class JambaMambaDecoderLayer(nn.Module): hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.mamba(hidden_states, attn_metadata, conv_state, - ssm_state) + hidden_states = self.mamba(hidden_states, attn_metadata, + mamba_cache_params) # Fully Connected hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual) @@ -476,17 +469,14 @@ class JambaModel(nn.Module): positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - conv_state: torch.Tensor, - ssm_state: torch.Tensor, + mamba_cache_params: MambaCacheParams, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None - for i in range(len(self.layers)): layer = self.layers[i] kv_cache = None - current_ssm_state = None - current_conv_state = None + layer_mamba_cache_params = None if isinstance(layer, JambaAttentionDecoderLayer): kv_cache = kv_caches[(i - self.config.attn_layer_offset) // self.config.attn_layer_period] @@ -494,8 +484,8 @@ class JambaModel(nn.Module): current_state_layer = i - (1 + (i - self.config.attn_layer_offset) // self.config.attn_layer_period) - current_ssm_state = ssm_state[current_state_layer] - current_conv_state = conv_state[current_state_layer] + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + current_state_layer) hidden_states, residual = layer( positions=positions, @@ -503,9 +493,7 @@ class JambaModel(nn.Module): kv_cache=kv_cache, attn_metadata=attn_metadata, residual=residual, - conv_state=current_conv_state, - ssm_state=current_ssm_state, - ) + mamba_cache_params=layer_mamba_cache_params) hidden_states, _ = self.final_layernorm(hidden_states, residual) return hidden_states @@ -588,13 +576,16 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): self.mamba_cache = MambaCacheManager( self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, *self._get_mamba_cache_shape()) - - mamba_cache_tensors = self.mamba_cache.current_run_tensors( - input_ids, attn_metadata, **kwargs) - + ( + mamba_cache_tensors, + state_indices_tensor, + ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, + **kwargs) + mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], + mamba_cache_tensors[1], + state_indices_tensor) hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, mamba_cache_tensors[0], - mamba_cache_tensors[1]) + attn_metadata, mamba_cache_params) return hidden_states def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index b86b687a..7f2efb98 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -27,7 +27,8 @@ from vllm.model_executor.model_loader.weight_utils import ( composed_weight_loader, default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.interfaces import (HasInnerState, IsAttentionFree) -from vllm.model_executor.models.mamba_cache import MambaCacheManager +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors @@ -110,8 +111,8 @@ class MambaMixer(nn.Module): self.activation = config.hidden_act def forward(self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, conv_state: torch.Tensor, - ssm_state: torch.Tensor): + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams): # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) @@ -134,17 +135,18 @@ class MambaMixer(nn.Module): conv_weights, self.conv1d.bias, activation=self.activation, - conv_states=conv_state, + conv_states=mamba_cache_params.conv_state, has_initial_state=attn_metadata.context_lens_tensor > 0, + cache_indices=mamba_cache_params.state_indices_tensor, query_start_loc=attn_metadata.query_start_loc) else: hidden_states = causal_conv1d_update( hidden_states.transpose(0, 1), - conv_state, + mamba_cache_params.conv_state, conv_weights, self.conv1d.bias, self.activation, - ) + conv_state_indices=mamba_cache_params.state_indices_tensor) hidden_states = hidden_states.transpose(0, 1) # 3. State Space Model sequence transformation @@ -168,7 +170,7 @@ class MambaMixer(nn.Module): and attn_metadata.context_lens_tensor is not None: scan_outputs = selective_scan_fn( hidden_states, - ssm_state, + mamba_cache_params.ssm_state, discrete_time_step, self.A, B.transpose(-2, -1), @@ -177,11 +179,12 @@ class MambaMixer(nn.Module): gate, time_proj_bias, delta_softplus=True, + cache_indices=mamba_cache_params.state_indices_tensor, has_initial_state=attn_metadata.context_lens_tensor > 0, query_start_loc=attn_metadata.query_start_loc) else: scan_outputs = selective_state_update( - ssm_state, + mamba_cache_params.ssm_state, hidden_states.transpose(0, 1), discrete_time_step.transpose(0, 1), self.A, @@ -191,7 +194,7 @@ class MambaMixer(nn.Module): gate.transpose(0, 1), time_proj_bias, dt_softplus=True, - ) + state_batch_indices=mamba_cache_params.state_indices_tensor) scan_outputs = scan_outputs.transpose(0, 1) # 4. Final linear projection @@ -221,8 +224,7 @@ class MambaDecoderLayer(nn.Module): hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], - conv_state: torch.Tensor, - ssm_state: torch.Tensor, + mamba_cache_params: MambaCacheParams, **kwargs, ): if residual is None: @@ -231,8 +233,8 @@ class MambaDecoderLayer(nn.Module): else: hidden_states, residual = self.norm(hidden_states, residual) - hidden_states = self.mixer(hidden_states, attn_metadata, conv_state, - ssm_state) + hidden_states = self.mixer(hidden_states, attn_metadata, + mamba_cache_params) return hidden_states, residual @@ -275,25 +277,20 @@ class MambaModel(nn.Module): input_ids: torch.Tensor, positions: torch.Tensor, attn_metadata: AttentionMetadata, - conv_state: torch.Tensor, - ssm_state: torch.Tensor, + mamba_cache_params: MambaCacheParams, ) -> torch.Tensor: + hidden_states = self.embeddings(input_ids) residual = None for i in range(len(self.layers)): layer = self.layers[i] - current_ssm_state = ssm_state[i] - current_conv_state = conv_state[i] - hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, attn_metadata=attn_metadata, residual=residual, - conv_state=current_conv_state, - ssm_state=current_ssm_state, - ) + mamba_cache_params=mamba_cache_params.at_layer_idx(i)) hidden_states, _ = self.norm_f(hidden_states, residual) return hidden_states @@ -347,12 +344,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): self.lm_head.weight.dtype, self.config.num_hidden_layers, max_batch_size, *self._get_mamba_cache_shape()) - mamba_cache_tensors = self.mamba_cache.current_run_tensors( - input_ids, attn_metadata, **kwargs) + ( + mamba_cache_tensors, + state_indices_tensor, + ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, + **kwargs) + + mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], + mamba_cache_tensors[1], + state_indices_tensor) hidden_states = self.backbone(input_ids, positions, attn_metadata, - mamba_cache_tensors[0], - mamba_cache_tensors[1]) + mamba_cache_params) return hidden_states diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index 8d1ba373..79393421 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -1,8 +1,22 @@ -from typing import Dict, List, Optional +from dataclasses import dataclass +from typing import Dict, List import torch from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.utils import PAD_SLOT_ID + + +@dataclass +class MambaCacheParams: + conv_state: torch.Tensor = torch.Tensor() + ssm_state: torch.Tensor = torch.Tensor() + state_indices_tensor: torch.Tensor = torch.Tensor() + + def at_layer_idx(self, layer_idx): + return MambaCacheParams(self.conv_state[layer_idx], + self.ssm_state[layer_idx], + self.state_indices_tensor) class MambaCacheManager: @@ -24,6 +38,7 @@ class MambaCacheManager: # Maps between the request id and a dict that maps between the seq_id # and its index inside the self.mamba_cache self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} + self.free_cache_indices = list(range(max_batch_size)) def current_run_tensors(self, input_ids: torch.Tensor, attn_metadata: AttentionMetadata, **kwargs): @@ -36,30 +51,43 @@ class MambaCacheManager: finished_requests_ids = kwargs["finished_requests_ids"] self._release_finished_requests(finished_requests_ids) - mamba_cache_tensors = self._prepare_current_run_mamba_cache( + state_indices = self._prepare_current_run_mamba_cache( request_ids_to_seq_ids, finished_requests_ids) + state_indices_tensor = torch.as_tensor(state_indices, + dtype=torch.int32, + device="cuda") + mamba_cache_tensors = self.mamba_cache + else: # CUDA graph capturing runs - mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"] + (mamba_cache_tensors, + state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"] - return mamba_cache_tensors + return (mamba_cache_tensors, state_indices_tensor) def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): """ - Copy the relevant Mamba cache into the CUDA graph input buffer - that was provided during the capture runs - (JambaForCausalLM.mamba_gc_cache_buffer). + Copy the relevant state_indices into the CUDA graph input buffer """ assert all( key in kwargs for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) finished_requests_ids = kwargs["finished_requests_ids"] request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + assert "seqlen_agnostic_capture_inputs" in input_buffers + _, input_state_indices_buffer = input_buffers[ + "seqlen_agnostic_capture_inputs"] self._release_finished_requests(finished_requests_ids) - self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - finished_requests_ids) + state_indices = self._prepare_current_run_mamba_cache( + request_ids_to_seq_ids, finished_requests_ids) + cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len( + state_indices) + state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len) + + input_state_indices_buffer.copy_( + torch.as_tensor(state_indices, dtype=torch.int32, device="cuda")) def get_seqlen_agnostic_capture_inputs(self, batch_size: int): """ @@ -67,13 +95,10 @@ class MambaCacheManager: The buffer is used to maintain the Mamba Cache during the CUDA graph replay runs. """ - return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache) - - def _swap_mamba_cache(self, from_index: int, to_index: int): - assert len(self.mamba_cache) > 0 - for cache_t in self.mamba_cache: - cache_t[:, [to_index,from_index]] = \ - cache_t[:, [from_index,to_index]] + state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size, + dtype=torch.int32, + device="cuda") + return (self.mamba_cache, state_indices_tensor) def _copy_mamba_cache(self, from_index: int, to_index: int): assert len(self.mamba_cache) > 0 @@ -81,142 +106,53 @@ class MambaCacheManager: cache_t[:, to_index].copy_(cache_t[:, from_index], non_blocking=True) - def _move_out_if_already_occupied(self, index: int, - all_occupied_indices: List[int]): - if index in all_occupied_indices: - first_free_index = self._first_free_index_in_mamba_cache() - # In case occupied, move the occupied to a new empty block - self._move_cache_index_and_mappings(from_index=index, - to_index=first_free_index) - - def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str, - seq_id: int, - destination_index: int): + def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int, + finished_requests_ids) -> int: """ Assign (req_id,seq_id) pair to a `destination_index` index, if already occupied, move the occupying index to a free index. """ - all_occupied_indices = self._get_all_occupied_indices() - if cur_rid not in self.mamba_cache_indices_mapping: - self._move_out_if_already_occupied( - index=destination_index, - all_occupied_indices=all_occupied_indices) + if cur_rid in finished_requests_ids: + # set as pad, do not allocate destination index + return PAD_SLOT_ID + elif cur_rid not in self.mamba_cache_indices_mapping: + destination_index = self.free_cache_indices.pop() self.mamba_cache_indices_mapping[cur_rid] = { seq_id: destination_index } + return destination_index elif seq_id not in (seq_ids2indices := self.mamba_cache_indices_mapping[cur_rid]): # parallel sampling , where n > 1, assume prefill have - # already happened now we only need to copy the already + # already happened, so we copy the # existing cache into the siblings seq_ids caches - self._move_out_if_already_occupied( - index=destination_index, - all_occupied_indices=all_occupied_indices) - index_exists = list(seq_ids2indices.values())[0] + index_exists = next(iter(seq_ids2indices.values())) # case of decoding n>1, copy prefill cache to decoding indices + destination_index = self.free_cache_indices.pop() self._copy_mamba_cache(from_index=index_exists, to_index=destination_index) self.mamba_cache_indices_mapping[cur_rid][ seq_id] = destination_index + return destination_index else: # already exists - cache_index_already_exists = self.mamba_cache_indices_mapping[ - cur_rid][seq_id] - if cache_index_already_exists != destination_index: - # In case the seq id already exists but not in - # the right destination, swap it with what's occupying it - self._swap_pair_indices_and_mappings( - from_index=cache_index_already_exists, - to_index=destination_index) + return self.mamba_cache_indices_mapping[cur_rid][seq_id] def _prepare_current_run_mamba_cache( self, request_ids_to_seq_ids: Dict[str, list[int]], - finished_requests_ids: List[str]): - running_indices = [] - request_ids_to_seq_ids_flatten = [ - (req_id, seq_id) + finished_requests_ids: List[str]) -> List[int]: + return [ + self._assign_seq_id_to_cache_index(req_id, seq_id, + finished_requests_ids) for req_id, seq_ids in request_ids_to_seq_ids.items() for seq_id in seq_ids ] - batch_size = len(request_ids_to_seq_ids_flatten) - for dest_index, (request_id, - seq_id) in enumerate(request_ids_to_seq_ids_flatten): - if request_id in finished_requests_ids: - # Do not allocate cache index for requests that run - # and finish right after - continue - self._assign_seq_id_to_mamba_cache_in_specific_dest( - request_id, seq_id, dest_index) - running_indices.append(dest_index) - - self._clean_up_first_bs_blocks(batch_size, running_indices) - conv_state = self.mamba_cache[0][:, :batch_size] - temporal_state = self.mamba_cache[1][:, :batch_size] - - return (conv_state, temporal_state) - - def _get_all_occupied_indices(self): - return [ - cache_idx - for seq_ids2indices in self.mamba_cache_indices_mapping.values() - for cache_idx in seq_ids2indices.values() - ] - - def _clean_up_first_bs_blocks(self, batch_size: int, - indices_for_current_run: List[int]): - # move out all of the occupied but currently not running blocks - # outside of the first n blocks - destination_indices = range(batch_size) - max_possible_batch_size = self.mamba_cache[0].shape[1] - for destination_index in destination_indices: - if destination_index in self._get_all_occupied_indices() and \ - destination_index not in indices_for_current_run: - # move not running indices outside of the batch - all_other_indices = list( - range(batch_size, max_possible_batch_size)) - first_avail_index = self._first_free_index_in_mamba_cache( - all_other_indices) - self._swap_indices(from_index=destination_index, - to_index=first_avail_index) - - def _move_cache_index_and_mappings(self, from_index: int, to_index: int): - self._copy_mamba_cache(from_index=from_index, to_index=to_index) - self._update_mapping_index(from_index=from_index, to_index=to_index) - - def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int): - self._swap_mamba_cache(from_index=from_index, to_index=to_index) - self._swap_mapping_index(from_index=from_index, to_index=to_index) - - def _swap_mapping_index(self, from_index: int, to_index: int): - for seq_ids2index in self.mamba_cache_indices_mapping.values(): - for seq_id, index in seq_ids2index.items(): - if from_index == index: - seq_ids2index.update({seq_id: to_index}) - elif to_index == index: - seq_ids2index.update({seq_id: from_index}) - - def _update_mapping_index(self, from_index: int, to_index: int): - for seq_ids2index in self.mamba_cache_indices_mapping.values(): - for seq_id, index in seq_ids2index.items(): - if from_index == index: - seq_ids2index.update({seq_id: to_index}) - return def _release_finished_requests(self, finished_seq_groups_req_ids: List[str]): for req_id in finished_seq_groups_req_ids: if req_id in self.mamba_cache_indices_mapping: + for seq_id in self.mamba_cache_indices_mapping[req_id]: + self.free_cache_indices.append( + self.mamba_cache_indices_mapping[req_id][seq_id]) self.mamba_cache_indices_mapping.pop(req_id) - - def _first_free_index_in_mamba_cache( - self, indices_range: Optional[List[int]] = None) -> int: - assert self.mamba_cache is not None - if indices_range is None: - max_possible_batch_size = self.mamba_cache[0].shape[1] - indices_range = list(range(max_possible_batch_size)) - all_occupied_indices = self._get_all_occupied_indices() - for i in indices_range: - if i not in all_occupied_indices: - return i - raise Exception("Couldn't find a free spot in the mamba cache! This" - "should never happen")