[Kernel][Model] Improve continuous batching for Jamba and Mamba (#9189)

This commit is contained in:
Mor Zusman 2024-10-17 00:12:43 +08:00 committed by GitHub
parent 415f76a9cb
commit fb60ae9b91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 504 additions and 432 deletions

View File

@ -55,6 +55,7 @@ void set_conv_params_fwd(ConvParamsBase &params,
const at::Tensor out,
const c10::optional<at::Tensor>& bias,
bool silu_activation,
int64_t pad_slot_id,
const c10::optional<at::Tensor>& query_start_loc = std::nullopt,
const c10::optional<at::Tensor>& cache_indices = std::nullopt,
const c10::optional<at::Tensor>& has_initial_state = std::nullopt) {
@ -66,6 +67,7 @@ void set_conv_params_fwd(ConvParamsBase &params,
params.dim = dim;
params.seqlen = seqlen;
params.width = width;
params.pad_slot_id = pad_slot_id;
params.silu_activation = silu_activation;
@ -90,14 +92,16 @@ void set_conv_params_fwd(ConvParamsBase &params,
}
at::Tensor
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
const c10::optional<at::Tensor> &bias_,
const c10::optional<at::Tensor> &conv_states,
const c10::optional<at::Tensor> &query_start_loc,
const c10::optional<at::Tensor> &cache_indices,
const c10::optional<at::Tensor> &has_initial_state,
bool silu_activation) {
bool silu_activation,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
auto input_type = x.scalar_type();
auto weight_type = weight.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@ -153,12 +157,13 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
CHECK_SHAPE(cache_indices_, batch_size);
}
at::Tensor out = torch::empty_like(x);
at::Tensor out = x;
ConvParamsBase params;
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
bias_,
silu_activation,
pad_slot_id,
query_start_loc,
cache_indices,
has_initial_state
@ -183,18 +188,19 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
});
return out;
}
at::Tensor
causal_conv1d_update(const at::Tensor &x,
void causal_conv1d_update(const at::Tensor &x,
const at::Tensor &conv_state,
const at::Tensor &weight,
const c10::optional<at::Tensor> &bias_,
bool silu_activation,
const c10::optional<at::Tensor> &cache_seqlens_,
const c10::optional<at::Tensor> &conv_state_indices_) {
const c10::optional<at::Tensor> &conv_state_indices_,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
auto input_type = x.scalar_type();
auto weight_type = weight.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@ -227,12 +233,13 @@ causal_conv1d_update(const at::Tensor &x,
CHECK_SHAPE(bias, dim);
}
at::Tensor out = torch::empty_like(x);
at::Tensor out = x;
ConvParamsBase params;
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
bias_,
silu_activation);
silu_activation,
pad_slot_id);
params.conv_state_ptr = conv_state.data_ptr();
params.conv_state_len = conv_state_len;
// All stride are in elements, not bytes.
@ -274,7 +281,6 @@ causal_conv1d_update(const at::Tensor &x,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
});
return out;
}
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
@ -340,7 +346,10 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if (cache_index == params.pad_slot_id){
return;
}
input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
: reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;
@ -528,6 +537,10 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
? batch_id
: params.conv_state_indices_ptr[batch_id];
// conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
if (conv_state_batch_coord == params.pad_slot_id){
return;
}
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
+ conv_state_batch_coord * params.conv_state_batch_stride
+ channel_id * params.conv_state_c_stride;

View File

@ -13,6 +13,7 @@ struct ConvParamsBase {
using index_t = uint32_t;
int batch, dim, seqlen, width;
int64_t pad_slot_id;
bool silu_activation;
index_t x_batch_stride;

View File

@ -21,6 +21,7 @@ struct SSMParamsBase {
int dim_ngroups_ratio;
bool is_variable_B;
bool is_variable_C;
int64_t pad_slot_id;
bool delta_softplus;

View File

@ -115,6 +115,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if (cache_index == params.pad_slot_id){
return;
}
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
+ dim_id * kNRows * params.u_d_stride;
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + sequence_start_index * params.delta_batch_stride
@ -387,7 +391,6 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const size_t seqlen,
const size_t dstate,
const size_t n_groups,
const size_t n_chunks,
const bool is_variable_B,
const bool is_variable_C,
// device pointers
@ -407,7 +410,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const c10::optional<at::Tensor>& query_start_loc,
const c10::optional<at::Tensor>& cache_indices,
const c10::optional<at::Tensor>& has_initial_state,
bool varlen) {
bool varlen,
int64_t pad_slot_id) {
// Reset the parameters
memset(&params, 0, sizeof(params));
@ -417,8 +421,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.seqlen = seqlen;
params.dstate = dstate;
params.n_groups = n_groups;
params.n_chunks = n_chunks;
params.dim_ngroups_ratio = dim / n_groups;
params.pad_slot_id = pad_slot_id;
params.delta_softplus = delta_softplus;
@ -507,7 +511,10 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const c10::optional<torch::Tensor> &query_start_loc,
const c10::optional<torch::Tensor> &cache_indices,
const c10::optional<torch::Tensor> &has_initial_state,
const torch::Tensor &ssm_states) {
const torch::Tensor &ssm_states,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
auto input_type = u.scalar_type();
auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@ -618,18 +625,14 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
out_z = z;
const int n_chunks = (seqlen + 2048 - 1) / 2048;
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
// at::Tensor out = torch::empty_like(u);
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
at::Tensor out = delta;
TORCH_CHECK(ssm_states.scalar_type() == input_type);
TORCH_CHECK(ssm_states.is_cuda());
TORCH_CHECK(ssm_states.stride(-1) == 1);
CHECK_SHAPE(ssm_states, batch_size, dim, dstate);
SSMParamsBase params;
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C,
u, delta, A, B, C, out, z, out_z,
D_,
delta_bias_,
@ -639,7 +642,8 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
query_start_loc,
cache_indices,
has_initial_state,
varlen
varlen,
pad_slot_id
);

View File

@ -157,21 +157,23 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const c10::optional<torch::Tensor>& query_start_loc,
const c10::optional<torch::Tensor>& cache_indices,
const c10::optional<torch::Tensor>& has_initial_state,
const torch::Tensor& ssm_states);
const torch::Tensor& ssm_states, int64_t pad_slot_id);
at::Tensor causal_conv1d_update(
const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_, bool silu_activation,
const c10::optional<at::Tensor>& cache_seqlens_,
const c10::optional<at::Tensor>& conv_state_indices_);
void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
bool silu_activation,
const c10::optional<at::Tensor>& cache_seqlens_,
const c10::optional<at::Tensor>& conv_state_indices_,
int64_t pad_slot_id);
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
const c10::optional<at::Tensor>& conv_states,
const c10::optional<at::Tensor>& query_start_loc,
const c10::optional<at::Tensor>& cache_indices,
const c10::optional<at::Tensor>& has_initial_state,
bool silu_activation);
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
const c10::optional<at::Tensor>& conv_states,
const c10::optional<at::Tensor>& query_start_loc,
const c10::optional<at::Tensor>& cache_indices,
const c10::optional<at::Tensor>& has_initial_state,
bool silu_activation, int64_t pad_slot_id);
#ifndef USE_ROCM
using fptr_t = int64_t;

View File

@ -278,7 +278,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor! ssm_states) -> ()");
"Tensor! ssm_states,"
"int pad_slot_id) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
ops.def(
@ -288,7 +289,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? bias_,"
"bool silu_activation,"
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices) -> Tensor");
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
ops.def(
@ -298,7 +300,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation) -> Tensor");
"bool silu_activation,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#endif

View File

@ -6,6 +6,7 @@ import torch.nn.functional as F
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.utils import seed_everything
@ -114,16 +115,15 @@ def causal_conv1d_update_ref(x,
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
def causal_conv1d_opcheck_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
cu_seq_len: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
):
def causal_conv1d_opcheck_fn(x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
cu_seq_len: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
@ -141,16 +141,9 @@ def causal_conv1d_opcheck_fn(
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
opcheck(torch.ops._C.causal_conv1d_fwd, (
x,
weight,
bias,
conv_states,
cu_seq_len,
cache_indices,
has_initial_state,
activation in ["silu", "swish"],
))
opcheck(torch.ops._C.causal_conv1d_fwd,
(x, weight, bias, conv_states, cu_seq_len, cache_indices,
has_initial_state, activation in ["silu", "swish"], pad_slot_id))
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@ -233,17 +226,11 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
seed_everything(0)
batch = 2
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
x_ref = x.clone()
conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
weight = torch.randn(dim,
width,
device=device,
dtype=itype,
requires_grad=True)
if has_bias:
bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
else:
bias = None
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state.detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x,
@ -251,7 +238,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
weight,
bias,
activation=activation)
out_ref = causal_conv1d_update_ref(x,
out_ref = causal_conv1d_update_ref(x_ref,
conv_state_ref,
weight,
bias,
@ -260,15 +247,9 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
assert torch.equal(conv_state, conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
opcheck(torch.ops._C.causal_conv1d_update, (
x,
conv_state,
weight,
bias,
activation in ["silu", "swish"],
None,
None,
))
opcheck(torch.ops._C.causal_conv1d_update,
(x, conv_state, weight, bias, activation
in ["silu", "swish"], None, None, PAD_SLOT_ID))
@pytest.mark.parametrize("itype",
@ -278,37 +259,48 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
@pytest.mark.parametrize("seqlen", [1, 4, 5])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
seqlen, has_bias,
silu_activation, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set )seed
# set seed
seed_everything(0)
batch = 64
x = torch.randn(batch, dim, 1, device=device, dtype=itype)
batch_size = 3
padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding
total_entries = 10 * batch_size
total_entries = 10 * batch
x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype)
x_ref = x.clone()
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device)
unused_states_bool = torch.ones(total_entries,
dtype=torch.bool,
device=device)
unused_states_bool[conv_state_indices] = False
padded_state_indices = torch.concat([
conv_state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
],
dim=0)
conv_state = torch.randn(total_entries,
dim,
width - 1,
device=device,
dtype=itype)
conv_state_indices = torch.randperm(total_entries)[:batch].to(
dtype=torch.int32, device=device)
conv_state_for_padding_test = conv_state.clone()
weight = torch.randn(dim,
width,
device=device,
dtype=itype,
requires_grad=True)
if has_bias:
bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
else:
bias = None
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x,
@ -316,45 +308,50 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
weight,
bias,
activation=activation,
conv_state_indices=conv_state_indices)
out_ref = causal_conv1d_update_ref(x,
conv_state_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID)
out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
conv_state_ref,
weight,
bias,
activation=activation)
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
assert torch.equal(conv_state[unused_states_bool],
conv_state_for_padding_test[unused_states_bool])
opcheck(torch.ops._C.causal_conv1d_update, (
x,
conv_state,
weight,
bias,
activation in ["silu", "swish"],
None,
conv_state_indices,
))
opcheck(torch.ops._C.causal_conv1d_update,
(x, conv_state, weight, bias, activation
in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID))
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize('seqlen',
[8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
@pytest.mark.parametrize(
'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096])
@pytest.mark.parametrize('dim', [64, 4096])
def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
itype):
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize('with_padding', [True, False])
def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
silu_activation, itype):
device = "cuda"
torch.cuda.empty_cache()
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
seed_everything(0)
batch = 1
seqlens = []
nsplits = 3
batch_size = 4
if seqlen < 10:
batch_size = 1
padding = 3 if with_padding else 0
padded_batch_size = batch_size + padding
nsplits = padded_batch_size - 1
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
seqlens.append(
torch.diff(
@ -364,10 +361,11 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
assert sum(seqlens[-1]) == seqlen
assert all(s > 0 for s in seqlens[-1])
total_entries = batch_size * 10
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0)
x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device,
x = torch.randn(1, 4096 + dim + 64, seqlen, device=device,
dtype=itype)[:, 4096:4096 + dim, :]
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
@ -375,7 +373,7 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
activation = None if not silu_activation else "silu"
final_states = torch.randn(nsplits + 1,
final_states = torch.randn(total_entries,
dim,
width - 1,
device=x.device,
@ -385,18 +383,27 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
2, (cumsum.shape[0] - 1, ),
dtype=torch.bool,
device=x.device)
cache_indices = torch.randperm(cumsum.shape[0] - 1,
state_indices = torch.randperm(total_entries,
dtype=torch.int32,
device=x.device)
device=x.device)[:batch_size]
padded_state_indices = torch.concat([
state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=-1)
out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
cache_indices, has_initial_states, final_states,
activation)
padded_state_indices, has_initial_states,
final_states, activation, PAD_SLOT_ID)
out_ref = []
out_ref_b = []
splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
for i in range(len(seqlens[0])):
x_s = [v[i].unsqueeze(0) for v in splits][0]
if padded_state_indices[i] == PAD_SLOT_ID:
continue
out_ref_b.append(
causal_conv1d_ref(
x_s,
@ -404,21 +411,17 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
bias_ref,
activation=activation,
return_final_states=True,
final_states_out=final_states_ref[cache_indices[i]].unsqueeze(
0),
initial_states=final_states_ref[cache_indices[i]].unsqueeze(0)
if has_initial_states[i] else None))
final_states_out=final_states_ref[
padded_state_indices[i]].unsqueeze(0),
initial_states=final_states_ref[padded_state_indices[i]].
unsqueeze(0) if has_initial_states[i] else None))
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
out_ref = torch.cat(out_ref, dim=0)
out_ref_tensor = torch.cat(out_ref, dim=0)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print("Output state max diff"
f":{(final_states - final_states_ref).abs().max()}")
print("Output state mean diff"
f":{(final_states - final_states_ref).abs().mean()}")
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
cache_indices, has_initial_states, final_states,
activation)
padded_state_indices, has_initial_states,
final_states, activation)

View File

@ -5,6 +5,7 @@ from einops import rearrange, repeat
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.utils import seed_everything
@ -174,7 +175,8 @@ def selective_scan_opcheck_fn(u,
cu_seq_len=None,
cache_indices=None,
has_initial_state=None,
ssm_states=None):
ssm_states=None,
pad_slot_id=PAD_SLOT_ID):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate).
"""
@ -203,7 +205,7 @@ def selective_scan_opcheck_fn(u,
# a bogus error.
opcheck(torch.ops._C.selective_scan_fwd,
(u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seq_len,
cache_indices, has_initial_state, ssm_states),
cache_indices, has_initial_state, ssm_states, pad_slot_id),
test_utils=["test_schema", "test_faketensor"])
@ -404,9 +406,12 @@ def test_selective_state_update(dim, dstate, has_z, itype):
@pytest.mark.parametrize("varBC_groups", [1, 2])
@pytest.mark.parametrize("is_variable_C", [True])
@pytest.mark.parametrize("is_variable_B", [True])
def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
has_D, has_z, has_delta_bias, delta_softplus,
return_last_state, seqlen, itype, wtype):
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [False, True])
def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C,
varBC_groups, has_D, has_z, has_delta_bias,
delta_softplus, return_last_state, seqlen,
itype, wtype):
if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
pytest.skip() # This config is not applicable
device = 'cuda'
@ -420,18 +425,27 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
# set seed
torch.random.manual_seed(0)
seqlens = []
nsplits = 3
batch_size = 4
if seqlen < 10:
nsplits = 0
batch_size = 1
padding = 3 if with_padding else 0
padded_batch_size = batch_size + padding
if with_padding and seqlen < padded_batch_size:
pytest.skip()
nsplits = padded_batch_size - 1
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
seqlens.append(
torch.diff(
torch.cat(
[torch.tensor([-1]), eos_pos,
torch.tensor([seqlen - 1])])).tolist())
assert sum(seqlens[-1]) == seqlen
assert all(s > 0 for s in seqlens[-1])
total_entries = batch_size * 10
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0).cuda()
@ -462,22 +476,33 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
delta_ref = delta.clone()
out = None
out_ref = None
prev_state_shape = (cumsum.shape[0] - 1, u.shape[0], int(A.shape[1]))
prev_state_shape = (total_entries, u.shape[0], int(A.shape[1]))
prev_state = torch.randn(prev_state_shape,
device=u.device,
dtype=itype,
requires_grad=False)
prev_state_ref = prev_state.clone()
cache_indices = torch.randperm(cumsum.shape[0] - 1,
state_indices = torch.randperm(total_entries,
dtype=torch.int32,
device=u.device)
device=u.device)[:batch_size]
unused_states_bool = torch.ones(total_entries,
dtype=torch.bool,
device=device)
unused_states_bool[state_indices] = False
padded_state_indices = torch.concat([
state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=-1)
has_initial_state = torch.randint(0,
2, (cumsum.shape[0] - 1, ),
dtype=torch.bool,
device=u.device)
out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias,
delta_softplus, cumsum, cache_indices,
delta_softplus, cumsum, padded_state_indices,
has_initial_state)
outs_ref = []
splits = [
@ -486,6 +511,8 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
]
for i in range(len(seqlens[0])):
u_s, delta_s, B_s, C_s, z_s = [v[i].unsqueeze(0) for v in splits]
if padded_state_indices[i] == PAD_SLOT_ID:
continue
out_ref_s, _ = selective_scan_ref(
u_s,
delta_s,
@ -497,21 +524,22 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
return_last_state=return_last_state,
prev_state=prev_state_ref[cache_indices[i]].unsqueeze(0)
prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0)
if has_initial_state[i] else None,
final_state_out=prev_state_ref[cache_indices[i]].unsqueeze(0))
final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze(
0))
outs_ref.append(out_ref_s)
out_ref = torch.cat(outs_ref, dim=-1) if len(outs_ref) > 1 else outs_ref[0]
out_ref = torch.cat(outs_ref, dim=-1)[0]
print("Output diff max", (out - out_ref[0]).max())
print("Output diff mean", (out - out_ref[0]).mean())
unpadded_out = out[:, :out_ref[0].shape[-1]]
print("Output diff max", (unpadded_out - out_ref).max())
print("Output diff mean", (unpadded_out - out_ref).mean())
print("Output state diff max", (prev_state - prev_state_ref).max())
print("Output state diff mean", (prev_state - prev_state_ref).mean())
assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref[0], rtol=rtol, atol=atol)
assert torch.allclose(unpadded_out, out_ref, rtol=rtol, atol=atol)
selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias,
delta_softplus, cumsum, cache_indices,
delta_softplus, cumsum, padded_state_indices,
has_initial_state, prev_state)
@ -520,7 +548,10 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
has_z, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
@ -530,21 +561,32 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
# set seed
torch.random.manual_seed(0)
batch_size = 3
padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding
total_entries = 10 * batch_size
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device)
x = torch.randn(batch_size, dim, device=device, dtype=itype)
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
unused_states_bool = torch.ones(total_entries,
dtype=torch.bool,
device=device)
unused_states_bool[state_indices] = False
padded_state_indices = torch.concat([
state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
],
dim=0)
x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(batch_size, dstate, device=device)
C = torch.randn(batch_size, dstate, device=device)
B = torch.randn(padded_batch_size, dstate, device=device)
C = torch.randn(padded_batch_size, dstate, device=device)
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state[state_indices, :].detach().clone()
state_ref = state[state_indices, :].clone()
state_before = state.clone()
out = selective_state_update(state,
x,
dt,
@ -555,15 +597,16 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices)
state_batch_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID)
out_ref = selective_state_update_ref(state_ref,
x,
dt,
x[:batch_size],
dt[:batch_size],
A,
B,
C,
B[:batch_size],
C[:batch_size],
D=D,
z=z,
z=z[:batch_size],
dt_bias=dt_bias,
dt_softplus=True)
@ -572,11 +615,21 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
print("Output state diff max", (state[state_indices, :] - state_ref).max())
print("Output state diff mean",
(state[state_indices, :] - state_ref).mean())
# test padded entries stay the same
if with_padding:
assert torch.equal(state_before[unused_states_bool],
state[unused_states_bool])
assert torch.equal(x[batch_size + 1:], x[batch_size + 1:])
assert torch.equal(dt[batch_size + 1:], dt[batch_size + 1:])
assert torch.equal(B[batch_size + 1:], B[batch_size + 1:])
assert torch.equal(C[batch_size + 1:], C[batch_size + 1:])
# test "real" entries
assert torch.allclose(state[state_indices, :],
state_ref,
rtol=rtol,
atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype",
@ -645,7 +698,8 @@ def test_selective_state_update_with_heads_with_batch_indices(
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices)
state_batch_indices=state_indices,
pad_slot_id=PAD_SLOT_ID)
out_ref = selective_state_update_ref(state_ref,
x,
dt,

View File

@ -1,5 +1,6 @@
import pytest
from tests.utils import multi_gpu_test
from vllm.sampling_params import SamplingParams
from vllm.worker.model_runner import _get_graph_batch_size
@ -270,6 +271,30 @@ def test_state_cleanup(
"could be related to finished_requests_ids")
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
def test_jamba_distributed_produces_identical_generation(
vllm_runner, model: str, dtype: str, max_tokens: int,
example_prompts) -> None:
with vllm_runner(model, dtype=dtype, tensor_parallel_size=2) as vllm_model:
vllm_outputs_tp_2 = vllm_model.generate_greedy(example_prompts,
max_tokens)
with vllm_runner(model, dtype=dtype, tensor_parallel_size=1) as vllm_model:
vllm_outputs_tp_1 = vllm_model.generate_greedy(example_prompts,
max_tokens)
check_outputs_equal(
outputs_0_lst=vllm_outputs_tp_1,
outputs_1_lst=vllm_outputs_tp_2,
name_0="vllm_tp_1",
name_1="vllm_tp_2",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_model_print(

View File

@ -464,16 +464,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
cu_seq_len: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
return torch.empty_like(x)
silu_activation: bool, pad_slot_id: int):
return None
@register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake(
x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor], silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
return torch.empty_like(x)
def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
weight: torch.Tensor,
bias_: Optional[torch.Tensor],
silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor],
pad_slot_id: int) -> None:
return None
@register_fake("_C::selective_scan_fwd")
def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
@ -485,7 +487,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
cu_seq_len: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
ssm_states: Optional[torch.Tensor]) -> None:
ssm_states: Optional[torch.Tensor],
pad_slot_id: int) -> None:
return None
@ -800,33 +803,37 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
query_start_loc: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
query_start_loc, cache_indices,
has_initial_state, silu_activation)
silu_activation: bool, pad_slot_id: int):
torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
query_start_loc, cache_indices,
has_initial_state, silu_activation,
pad_slot_id)
def causal_conv1d_update(
x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor], silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
silu_activation, cache_seqlens,
conv_state_indices)
def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
weight: torch.Tensor, bias_: Optional[torch.Tensor],
silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor],
pad_slot_id: int):
torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
silu_activation, cache_seqlens,
conv_state_indices, pad_slot_id)
def selective_scan_fwd(
u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor,
C: torch.Tensor, D_: Optional[torch.Tensor],
z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor],
delta_softplus: bool, query_start_loc: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor], ssm_states: torch.Tensor):
def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, C: torch.Tensor,
D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
delta_bias_: Optional[torch.Tensor],
delta_softplus: bool,
query_start_loc: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
ssm_states: torch.Tensor, pad_slot_id: int):
torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_,
delta_softplus, query_start_loc,
cache_indices, has_initial_state,
ssm_states)
ssm_states, pad_slot_id)
# moe

View File

@ -6,18 +6,18 @@ from typing import Optional
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID
def causal_conv1d_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
query_start_loc: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
):
def causal_conv1d_fn(x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
query_start_loc: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID):
"""
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
sequences are concatenated from left to right for varlen
@ -37,6 +37,13 @@ def causal_conv1d_fn(
conv_states: (...,dim,width - 1) itype
updated inplace if provided
activation: either None or "silu" or "swish"
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim, seqlen)
"""
@ -46,10 +53,10 @@ def causal_conv1d_fn(
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
out = ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc,
cache_indices, has_initial_state, activation
in ["silu", "swish"])
return out
ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc,
cache_indices, has_initial_state, activation
in ["silu", "swish"], pad_slot_id)
return x
def causal_conv1d_update(x: torch.Tensor,
@ -58,7 +65,8 @@ def causal_conv1d_update(x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = None,
cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None):
conv_state_indices: Optional[torch.Tensor] = None,
pad_slot_id: int = PAD_SLOT_ID):
"""
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
@ -73,7 +81,12 @@ def causal_conv1d_update(x: torch.Tensor,
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
@ -82,8 +95,8 @@ def causal_conv1d_update(x: torch.Tensor,
unsqueeze = x.dim() == 2
if unsqueeze:
x = x.unsqueeze(-1)
out = ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val,
cache_seqlens, conv_state_indices)
ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val,
cache_seqlens, conv_state_indices, pad_slot_id)
if unsqueeze:
out = out.squeeze(-1)
return out
x = x.squeeze(-1)
return x

View File

@ -1,14 +1,13 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
from typing import Tuple
import torch
import triton
import triton.language as tl
from packaging import version
from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID
TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
@ -50,6 +49,7 @@ def _selective_scan_update_kernel(
z_ptr,
out_ptr,
state_batch_indices_ptr,
pad_slot_id,
# Matrix dimensions
batch,
nheads,
@ -143,10 +143,11 @@ def _selective_scan_update_kernel(
if HAS_Z:
z_ptrs = z_ptr + offs_m * stride_z_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= (state_batch_idx != pad_slot_id)
state = tl.load(state_ptrs, mask=mask, other=0.0)
state = tl.load(state_ptrs,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
other=0.0)
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if not TIE_HDIM:
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
@ -177,9 +178,11 @@ def _selective_scan_update_kernel(
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
state = state * dA + dB * x[:, None]
tl.store(state_ptrs,
state,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= (state_batch_idx != pad_slot_id)
tl.store(state_ptrs, state, mask=mask)
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
out += x * D
@ -198,7 +201,8 @@ def selective_state_update(state,
z=None,
dt_bias=None,
dt_softplus=False,
state_batch_indices=None):
state_batch_indices=None,
pad_slot_id=PAD_SLOT_ID):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
@ -210,6 +214,12 @@ def selective_state_update(state,
D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim)
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
Return:
out: (batch, dim) or (batch, nheads, dim)
"""
@ -276,6 +286,7 @@ def selective_state_update(state,
z,
out,
state_batch_indices,
pad_slot_id,
batch,
nheads,
dim,
@ -319,22 +330,25 @@ def selective_state_update(state,
return out
def selective_scan_fn(
u,
ssm_states,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
query_start_loc=None,
cache_indices=None,
has_initial_state=None) -> Tuple[torch.Tensor, torch.Tensor]:
def selective_scan_fn(u,
ssm_states,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
query_start_loc=None,
cache_indices=None,
has_initial_state=None,
pad_slot_id=PAD_SLOT_ID) -> torch.Tensor:
"""
u: (dim, total_length) for varlen or (batch, dim, seqlen)
applies changes in place.
ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
applies changes in place.
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
A: (dim, dstate)
B: (ngroups, dstate, total_length) for varlen or
@ -357,12 +371,14 @@ def selective_scan_fn(
indicate if the ssm_state at the corresponding index should be
used as initial state. Not providing argument assumes
there's no initial state
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padding entries
that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at indices 0 and 3
returns
output: (dim, total_length) for varlen or (batch, dim, seqlen)
supports inplace replacement
last_state has shape (batch, dim, dstate).
supports inplace replacement if ssm_state was provided
"""
if u.stride(-1) != 1:
u = u.contiguous()
@ -387,7 +403,7 @@ def selective_scan_fn(
ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus,
query_start_loc, cache_indices, has_initial_state,
ssm_states)
ssm_states, pad_slot_id)
if z is None:
return delta # output written inplace to delta

View File

@ -1,6 +1,5 @@
# coding=utf-8
"""Inference-only Jamba model."""
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple
import torch
@ -29,7 +28,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheManager
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
@ -41,13 +41,6 @@ from .interfaces import HasInnerState, SupportsLoRA
KVCache = Tuple[torch.Tensor, torch.Tensor]
@dataclass
class MambaCacheParams:
is_prompt: bool = False
conv_state: torch.Tensor = torch.Tensor()
ssm_state: torch.Tensor = torch.Tensor()
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class JambaMambaMixer(nn.Module):
"""
@ -60,10 +53,9 @@ class JambaMambaMixer(nn.Module):
**selective** state spaces)
"""
def __init__(self, config: JambaConfig, layer_idx):
def __init__(self, config: JambaConfig):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.ssm_state_size = config.mamba_d_state
self.conv_kernel_size = config.mamba_d_conv
@ -129,8 +121,8 @@ class JambaMambaMixer(nn.Module):
eps=config.rms_norm_eps)
def forward(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata, conv_state: torch.Tensor,
ssm_state: torch.Tensor):
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@ -153,17 +145,18 @@ class JambaMambaMixer(nn.Module):
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
conv_state,
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
)
conv_state_indices=mamba_cache_params.state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1)
# 3. State Space Model sequence transformation
@ -188,7 +181,7 @@ class JambaMambaMixer(nn.Module):
and attn_metadata.context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
ssm_state,
mamba_cache_params.ssm_state,
discrete_time_step,
self.A,
B.transpose(-2, -1),
@ -197,11 +190,12 @@ class JambaMambaMixer(nn.Module):
gate,
time_proj_bias,
delta_softplus=True,
cache_indices=mamba_cache_params.state_indices_tensor,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
scan_outputs = selective_state_update(
ssm_state,
mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
@ -211,7 +205,7 @@ class JambaMambaMixer(nn.Module):
gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
)
state_batch_indices=mamba_cache_params.state_indices_tensor)
scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection
@ -292,7 +286,7 @@ class JambaMambaDecoderLayer(nn.Module):
super().__init__()
self.layer_idx = layer_idx
self.config = config
self.mamba = JambaMambaMixer(config, layer_idx)
self.mamba = JambaMambaMixer(config)
num_experts = config.layers_num_experts[layer_idx]
ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
@ -307,8 +301,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
mamba_cache_params: MambaCacheParams,
**kwargs,
):
if residual is None:
@ -318,8 +311,8 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.mamba(hidden_states, attn_metadata, conv_state,
ssm_state)
hidden_states = self.mamba(hidden_states, attn_metadata,
mamba_cache_params)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual)
@ -476,17 +469,14 @@ class JambaModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
mamba_cache_params: MambaCacheParams,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
kv_cache = None
current_ssm_state = None
current_conv_state = None
layer_mamba_cache_params = None
if isinstance(layer, JambaAttentionDecoderLayer):
kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
self.config.attn_layer_period]
@ -494,8 +484,8 @@ class JambaModel(nn.Module):
current_state_layer = i - (1 +
(i - self.config.attn_layer_offset)
// self.config.attn_layer_period)
current_ssm_state = ssm_state[current_state_layer]
current_conv_state = conv_state[current_state_layer]
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
current_state_layer)
hidden_states, residual = layer(
positions=positions,
@ -503,9 +493,7 @@ class JambaModel(nn.Module):
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=residual,
conv_state=current_conv_state,
ssm_state=current_ssm_state,
)
mamba_cache_params=layer_mamba_cache_params)
hidden_states, _ = self.final_layernorm(hidden_states, residual)
return hidden_states
@ -588,13 +576,16 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
*self._get_mamba_cache_shape())
mamba_cache_tensors = self.mamba_cache.current_run_tensors(
input_ids, attn_metadata, **kwargs)
(
mamba_cache_tensors,
state_indices_tensor,
) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
**kwargs)
mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
mamba_cache_tensors[1],
state_indices_tensor)
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, mamba_cache_tensors[0],
mamba_cache_tensors[1])
attn_metadata, mamba_cache_params)
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):

View File

@ -27,7 +27,8 @@ from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.interfaces import (HasInnerState,
IsAttentionFree)
from vllm.model_executor.models.mamba_cache import MambaCacheManager
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
@ -110,8 +111,8 @@ class MambaMixer(nn.Module):
self.activation = config.hidden_act
def forward(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata, conv_state: torch.Tensor,
ssm_state: torch.Tensor):
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@ -134,17 +135,18 @@ class MambaMixer(nn.Module):
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
conv_state,
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
)
conv_state_indices=mamba_cache_params.state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1)
# 3. State Space Model sequence transformation
@ -168,7 +170,7 @@ class MambaMixer(nn.Module):
and attn_metadata.context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
ssm_state,
mamba_cache_params.ssm_state,
discrete_time_step,
self.A,
B.transpose(-2, -1),
@ -177,11 +179,12 @@ class MambaMixer(nn.Module):
gate,
time_proj_bias,
delta_softplus=True,
cache_indices=mamba_cache_params.state_indices_tensor,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
scan_outputs = selective_state_update(
ssm_state,
mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
@ -191,7 +194,7 @@ class MambaMixer(nn.Module):
gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
)
state_batch_indices=mamba_cache_params.state_indices_tensor)
scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection
@ -221,8 +224,7 @@ class MambaDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
mamba_cache_params: MambaCacheParams,
**kwargs,
):
if residual is None:
@ -231,8 +233,8 @@ class MambaDecoderLayer(nn.Module):
else:
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
ssm_state)
hidden_states = self.mixer(hidden_states, attn_metadata,
mamba_cache_params)
return hidden_states, residual
@ -275,25 +277,20 @@ class MambaModel(nn.Module):
input_ids: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
mamba_cache_params: MambaCacheParams,
) -> torch.Tensor:
hidden_states = self.embeddings(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
current_ssm_state = ssm_state[i]
current_conv_state = conv_state[i]
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
conv_state=current_conv_state,
ssm_state=current_ssm_state,
)
mamba_cache_params=mamba_cache_params.at_layer_idx(i))
hidden_states, _ = self.norm_f(hidden_states, residual)
return hidden_states
@ -347,12 +344,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
self.lm_head.weight.dtype, self.config.num_hidden_layers,
max_batch_size, *self._get_mamba_cache_shape())
mamba_cache_tensors = self.mamba_cache.current_run_tensors(
input_ids, attn_metadata, **kwargs)
(
mamba_cache_tensors,
state_indices_tensor,
) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
**kwargs)
mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
mamba_cache_tensors[1],
state_indices_tensor)
hidden_states = self.backbone(input_ids, positions, attn_metadata,
mamba_cache_tensors[0],
mamba_cache_tensors[1])
mamba_cache_params)
return hidden_states

View File

@ -1,8 +1,22 @@
from typing import Dict, List, Optional
from dataclasses import dataclass
from typing import Dict, List
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.utils import PAD_SLOT_ID
@dataclass
class MambaCacheParams:
conv_state: torch.Tensor = torch.Tensor()
ssm_state: torch.Tensor = torch.Tensor()
state_indices_tensor: torch.Tensor = torch.Tensor()
def at_layer_idx(self, layer_idx):
return MambaCacheParams(self.conv_state[layer_idx],
self.ssm_state[layer_idx],
self.state_indices_tensor)
class MambaCacheManager:
@ -24,6 +38,7 @@ class MambaCacheManager:
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size))
def current_run_tensors(self, input_ids: torch.Tensor,
attn_metadata: AttentionMetadata, **kwargs):
@ -36,30 +51,43 @@ class MambaCacheManager:
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_finished_requests(finished_requests_ids)
mamba_cache_tensors = self._prepare_current_run_mamba_cache(
state_indices = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, finished_requests_ids)
state_indices_tensor = torch.as_tensor(state_indices,
dtype=torch.int32,
device="cuda")
mamba_cache_tensors = self.mamba_cache
else:
# CUDA graph capturing runs
mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"]
(mamba_cache_tensors,
state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"]
return mamba_cache_tensors
return (mamba_cache_tensors, state_indices_tensor)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant Mamba cache into the CUDA graph input buffer
that was provided during the capture runs
(JambaForCausalLM.mamba_gc_cache_buffer).
Copy the relevant state_indices into the CUDA graph input buffer
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
assert "seqlen_agnostic_capture_inputs" in input_buffers
_, input_state_indices_buffer = input_buffers[
"seqlen_agnostic_capture_inputs"]
self._release_finished_requests(finished_requests_ids)
self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
finished_requests_ids)
state_indices = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, finished_requests_ids)
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
state_indices)
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
input_state_indices_buffer.copy_(
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
@ -67,13 +95,10 @@ class MambaCacheManager:
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)
def _swap_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for cache_t in self.mamba_cache:
cache_t[:, [to_index,from_index]] = \
cache_t[:, [from_index,to_index]]
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")
return (self.mamba_cache, state_indices_tensor)
def _copy_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
@ -81,142 +106,53 @@ class MambaCacheManager:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
def _move_out_if_already_occupied(self, index: int,
all_occupied_indices: List[int]):
if index in all_occupied_indices:
first_free_index = self._first_free_index_in_mamba_cache()
# In case occupied, move the occupied to a new empty block
self._move_cache_index_and_mappings(from_index=index,
to_index=first_free_index)
def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
seq_id: int,
destination_index: int):
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
finished_requests_ids) -> int:
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
all_occupied_indices = self._get_all_occupied_indices()
if cur_rid not in self.mamba_cache_indices_mapping:
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
if cur_rid in finished_requests_ids:
# set as pad, do not allocate destination index
return PAD_SLOT_ID
elif cur_rid not in self.mamba_cache_indices_mapping:
destination_index = self.free_cache_indices.pop()
self.mamba_cache_indices_mapping[cur_rid] = {
seq_id: destination_index
}
return destination_index
elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]):
# parallel sampling , where n > 1, assume prefill have
# already happened now we only need to copy the already
# already happened, so we copy the
# existing cache into the siblings seq_ids caches
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
index_exists = list(seq_ids2indices.values())[0]
index_exists = next(iter(seq_ids2indices.values()))
# case of decoding n>1, copy prefill cache to decoding indices
destination_index = self.free_cache_indices.pop()
self._copy_mamba_cache(from_index=index_exists,
to_index=destination_index)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = destination_index
return destination_index
else:
# already exists
cache_index_already_exists = self.mamba_cache_indices_mapping[
cur_rid][seq_id]
if cache_index_already_exists != destination_index:
# In case the seq id already exists but not in
# the right destination, swap it with what's occupying it
self._swap_pair_indices_and_mappings(
from_index=cache_index_already_exists,
to_index=destination_index)
return self.mamba_cache_indices_mapping[cur_rid][seq_id]
def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
finished_requests_ids: List[str]):
running_indices = []
request_ids_to_seq_ids_flatten = [
(req_id, seq_id)
finished_requests_ids: List[str]) -> List[int]:
return [
self._assign_seq_id_to_cache_index(req_id, seq_id,
finished_requests_ids)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
batch_size = len(request_ids_to_seq_ids_flatten)
for dest_index, (request_id,
seq_id) in enumerate(request_ids_to_seq_ids_flatten):
if request_id in finished_requests_ids:
# Do not allocate cache index for requests that run
# and finish right after
continue
self._assign_seq_id_to_mamba_cache_in_specific_dest(
request_id, seq_id, dest_index)
running_indices.append(dest_index)
self._clean_up_first_bs_blocks(batch_size, running_indices)
conv_state = self.mamba_cache[0][:, :batch_size]
temporal_state = self.mamba_cache[1][:, :batch_size]
return (conv_state, temporal_state)
def _get_all_occupied_indices(self):
return [
cache_idx
for seq_ids2indices in self.mamba_cache_indices_mapping.values()
for cache_idx in seq_ids2indices.values()
]
def _clean_up_first_bs_blocks(self, batch_size: int,
indices_for_current_run: List[int]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
destination_indices = range(batch_size)
max_possible_batch_size = self.mamba_cache[0].shape[1]
for destination_index in destination_indices:
if destination_index in self._get_all_occupied_indices() and \
destination_index not in indices_for_current_run:
# move not running indices outside of the batch
all_other_indices = list(
range(batch_size, max_possible_batch_size))
first_avail_index = self._first_free_index_in_mamba_cache(
all_other_indices)
self._swap_indices(from_index=destination_index,
to_index=first_avail_index)
def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
self._copy_mamba_cache(from_index=from_index, to_index=to_index)
self._update_mapping_index(from_index=from_index, to_index=to_index)
def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
self._swap_mamba_cache(from_index=from_index, to_index=to_index)
self._swap_mapping_index(from_index=from_index, to_index=to_index)
def _swap_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
elif to_index == index:
seq_ids2index.update({seq_id: from_index})
def _update_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
return
def _release_finished_requests(self,
finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.mamba_cache_indices_mapping:
for seq_id in self.mamba_cache_indices_mapping[req_id]:
self.free_cache_indices.append(
self.mamba_cache_indices_mapping[req_id][seq_id])
self.mamba_cache_indices_mapping.pop(req_id)
def _first_free_index_in_mamba_cache(
self, indices_range: Optional[List[int]] = None) -> int:
assert self.mamba_cache is not None
if indices_range is None:
max_possible_batch_size = self.mamba_cache[0].shape[1]
indices_range = list(range(max_possible_batch_size))
all_occupied_indices = self._get_all_occupied_indices()
for i in indices_range:
if i not in all_occupied_indices:
return i
raise Exception("Couldn't find a free spot in the mamba cache! This"
"should never happen")