[Kernel][Model] Improve continuous batching for Jamba and Mamba (#9189)

parent 415f76a9cb
commit fb60ae9b91
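The change below threads a `pad_slot_id` argument through the Mamba/causal-conv kernels: any entry of `cache_indices` / `conv_state_indices` equal to that sentinel marks a padded batch slot, and the kernel returns early for it, so a batch can be padded to a fixed size (e.g. for CUDA graphs) without corrupting unrelated cache entries. A minimal sketch of that masking idea in plain PyTorch (PAD_SLOT_ID and the index layout are taken from the diff; the helper itself and the -1 value are assumptions for illustration):

import torch

PAD_SLOT_ID = -1  # sentinel, assumed to match vllm.attention.backends.utils

def scatter_states(cache: torch.Tensor, new_states: torch.Tensor,
                   cache_indices: torch.Tensor) -> None:
    """Write per-sequence states into a large cache, skipping padded slots.

    cache:         (total_entries, dim) persistent state buffer
    new_states:    (padded_batch, dim) states produced for the current step
    cache_indices: (padded_batch,) slot per sequence, PAD_SLOT_ID for padding
    """
    real = cache_indices != PAD_SLOT_ID              # mask out padded entries
    cache[cache_indices[real].long()] = new_states[real]

# usage: a batch of 2 real sequences padded to 4 entries
cache = torch.zeros(8, 3)
new = torch.arange(12, dtype=torch.float32).reshape(4, 3)
idx = torch.tensor([5, 2, PAD_SLOT_ID, PAD_SLOT_ID])
scatter_states(cache, new, idx)  # rows 5 and 2 are updated, padding is ignored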
@@ -55,6 +55,7 @@ void set_conv_params_fwd(ConvParamsBase &params,
const at::Tensor out,
const c10::optional<at::Tensor>& bias,
bool silu_activation,
int64_t pad_slot_id,
const c10::optional<at::Tensor>& query_start_loc = std::nullopt,
const c10::optional<at::Tensor>& cache_indices = std::nullopt,
const c10::optional<at::Tensor>& has_initial_state = std::nullopt) {
@@ -66,6 +67,7 @@ void set_conv_params_fwd(ConvParamsBase &params,
params.dim = dim;
params.seqlen = seqlen;
params.width = width;
params.pad_slot_id = pad_slot_id;

params.silu_activation = silu_activation;

@@ -90,14 +92,16 @@ void set_conv_params_fwd(ConvParamsBase &params,
}

at::Tensor
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
const c10::optional<at::Tensor> &bias_,
const c10::optional<at::Tensor> &conv_states,
const c10::optional<at::Tensor> &query_start_loc,
const c10::optional<at::Tensor> &cache_indices,
const c10::optional<at::Tensor> &has_initial_state,
bool silu_activation) {
bool silu_activation,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
auto input_type = x.scalar_type();
auto weight_type = weight.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -153,12 +157,13 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
CHECK_SHAPE(cache_indices_, batch_size);
}

at::Tensor out = torch::empty_like(x);
at::Tensor out = x;

ConvParamsBase params;
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
bias_,
silu_activation,
pad_slot_id,
query_start_loc,
cache_indices,
has_initial_state
@@ -183,18 +188,19 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
});
return out;
}

at::Tensor
causal_conv1d_update(const at::Tensor &x,
void causal_conv1d_update(const at::Tensor &x,
const at::Tensor &conv_state,
const at::Tensor &weight,
const c10::optional<at::Tensor> &bias_,
bool silu_activation,
const c10::optional<at::Tensor> &cache_seqlens_,
const c10::optional<at::Tensor> &conv_state_indices_) {
const c10::optional<at::Tensor> &conv_state_indices_,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
auto input_type = x.scalar_type();
auto weight_type = weight.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -227,12 +233,13 @@ causal_conv1d_update(const at::Tensor &x,
CHECK_SHAPE(bias, dim);
}

at::Tensor out = torch::empty_like(x);
at::Tensor out = x;

ConvParamsBase params;
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
bias_,
silu_activation);
silu_activation,
pad_slot_id);
params.conv_state_ptr = conv_state.data_ptr();
params.conv_state_len = conv_state_len;
// All strides are in elements, not bytes.
@@ -274,7 +281,6 @@ causal_conv1d_update(const at::Tensor &x,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
});
return out;
}

template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
@@ -340,7 +346,10 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];

// cache_index == params.pad_slot_id is defined as padding, so we exit early
if (cache_index == params.pad_slot_id){
return;
}
input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
: reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;

@@ -528,6 +537,10 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
? batch_id
: params.conv_state_indices_ptr[batch_id];
// conv_state_batch_coord == params.pad_slot_id is defined as padding, so we exit early
if (conv_state_batch_coord == params.pad_slot_id){
return;
}
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
+ conv_state_batch_coord * params.conv_state_batch_stride
+ channel_id * params.conv_state_c_stride;
@@ -13,6 +13,7 @@ struct ConvParamsBase {
using index_t = uint32_t;

int batch, dim, seqlen, width;
int64_t pad_slot_id;
bool silu_activation;

index_t x_batch_stride;

@@ -21,6 +21,7 @@ struct SSMParamsBase {
int dim_ngroups_ratio;
bool is_variable_B;
bool is_variable_C;
int64_t pad_slot_id;

bool delta_softplus;
@@ -115,6 +115,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if (cache_index == params.pad_slot_id){
return;
}
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
+ dim_id * kNRows * params.u_d_stride;
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + sequence_start_index * params.delta_batch_stride
@@ -387,7 +391,6 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const size_t seqlen,
const size_t dstate,
const size_t n_groups,
const size_t n_chunks,
const bool is_variable_B,
const bool is_variable_C,
// device pointers
@@ -407,7 +410,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const c10::optional<at::Tensor>& query_start_loc,
const c10::optional<at::Tensor>& cache_indices,
const c10::optional<at::Tensor>& has_initial_state,
bool varlen) {
bool varlen,
int64_t pad_slot_id) {

// Reset the parameters
memset(&params, 0, sizeof(params));
@@ -417,8 +421,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.seqlen = seqlen;
params.dstate = dstate;
params.n_groups = n_groups;
params.n_chunks = n_chunks;
params.dim_ngroups_ratio = dim / n_groups;
params.pad_slot_id = pad_slot_id;

params.delta_softplus = delta_softplus;

@@ -507,7 +511,10 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const c10::optional<torch::Tensor> &query_start_loc,
const c10::optional<torch::Tensor> &cache_indices,
const c10::optional<torch::Tensor> &has_initial_state,
const torch::Tensor &ssm_states) {
const torch::Tensor &ssm_states,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
auto input_type = u.scalar_type();
auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -618,18 +625,14 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,

out_z = z;

const int n_chunks = (seqlen + 2048 - 1) / 2048;
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
// at::Tensor out = torch::empty_like(u);
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
at::Tensor out = delta;
TORCH_CHECK(ssm_states.scalar_type() == input_type);
TORCH_CHECK(ssm_states.is_cuda());
TORCH_CHECK(ssm_states.stride(-1) == 1);
CHECK_SHAPE(ssm_states, batch_size, dim, dstate);

SSMParamsBase params;
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C,
u, delta, A, B, C, out, z, out_z,
D_,
delta_bias_,
@@ -639,7 +642,8 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
query_start_loc,
cache_indices,
has_initial_state,
varlen
varlen,
pad_slot_id
);
csrc/ops.h

@@ -157,21 +157,23 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const c10::optional<torch::Tensor>& query_start_loc,
const c10::optional<torch::Tensor>& cache_indices,
const c10::optional<torch::Tensor>& has_initial_state,
const torch::Tensor& ssm_states);
const torch::Tensor& ssm_states, int64_t pad_slot_id);

at::Tensor causal_conv1d_update(
const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_, bool silu_activation,
const c10::optional<at::Tensor>& cache_seqlens_,
const c10::optional<at::Tensor>& conv_state_indices_);
void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
bool silu_activation,
const c10::optional<at::Tensor>& cache_seqlens_,
const c10::optional<at::Tensor>& conv_state_indices_,
int64_t pad_slot_id);

at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
const c10::optional<at::Tensor>& conv_states,
const c10::optional<at::Tensor>& query_start_loc,
const c10::optional<at::Tensor>& cache_indices,
const c10::optional<at::Tensor>& has_initial_state,
bool silu_activation);
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
const c10::optional<at::Tensor>& conv_states,
const c10::optional<at::Tensor>& query_start_loc,
const c10::optional<at::Tensor>& cache_indices,
const c10::optional<at::Tensor>& has_initial_state,
bool silu_activation, int64_t pad_slot_id);

#ifndef USE_ROCM
using fptr_t = int64_t;
@@ -278,7 +278,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor! ssm_states) -> ()");
"Tensor! ssm_states,"
"int pad_slot_id) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

ops.def(
@@ -288,7 +289,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? bias_,"
"bool silu_activation,"
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices) -> Tensor");
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);

ops.def(
@@ -298,7 +300,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation) -> Tensor");
"bool silu_activation,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#endif
@ -6,6 +6,7 @@ import torch.nn.functional as F
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops # noqa: F401
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.utils import seed_everything
|
||||
@ -114,16 +115,15 @@ def causal_conv1d_update_ref(x,
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
|
||||
@pytest.mark.parametrize("silu_activation", [True])
|
||||
@pytest.mark.parametrize("has_bias", [True])
|
||||
def causal_conv1d_opcheck_fn(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
cu_seq_len: Optional[torch.Tensor] = None,
|
||||
cache_indices: Optional[torch.Tensor] = None,
|
||||
has_initial_state: Optional[torch.Tensor] = None,
|
||||
conv_states: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
):
|
||||
def causal_conv1d_opcheck_fn(x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
cu_seq_len: Optional[torch.Tensor] = None,
|
||||
cache_indices: Optional[torch.Tensor] = None,
|
||||
has_initial_state: Optional[torch.Tensor] = None,
|
||||
conv_states: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
pad_slot_id: int = PAD_SLOT_ID):
|
||||
"""
|
||||
x: (batch, dim, seqlen)
|
||||
weight: (dim, width)
|
||||
@ -141,16 +141,9 @@ def causal_conv1d_opcheck_fn(
|
||||
x = x.contiguous()
|
||||
bias = bias.contiguous() if bias is not None else None
|
||||
|
||||
opcheck(torch.ops._C.causal_conv1d_fwd, (
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
conv_states,
|
||||
cu_seq_len,
|
||||
cache_indices,
|
||||
has_initial_state,
|
||||
activation in ["silu", "swish"],
|
||||
))
|
||||
opcheck(torch.ops._C.causal_conv1d_fwd,
|
||||
(x, weight, bias, conv_states, cu_seq_len, cache_indices,
|
||||
has_initial_state, activation in ["silu", "swish"], pad_slot_id))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
|
||||
@ -233,17 +226,11 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
|
||||
seed_everything(0)
|
||||
batch = 2
|
||||
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
|
||||
x_ref = x.clone()
|
||||
conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
|
||||
|
||||
weight = torch.randn(dim,
|
||||
width,
|
||||
device=device,
|
||||
dtype=itype,
|
||||
requires_grad=True)
|
||||
if has_bias:
|
||||
bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
|
||||
else:
|
||||
bias = None
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
conv_state_ref = conv_state.detach().clone()
|
||||
activation = None if not silu_activation else "silu"
|
||||
out = causal_conv1d_update(x,
|
||||
@ -251,7 +238,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation)
|
||||
out_ref = causal_conv1d_update_ref(x,
|
||||
out_ref = causal_conv1d_update_ref(x_ref,
|
||||
conv_state_ref,
|
||||
weight,
|
||||
bias,
|
||||
@ -260,15 +247,9 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
|
||||
assert torch.equal(conv_state, conv_state_ref)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
opcheck(torch.ops._C.causal_conv1d_update, (
|
||||
x,
|
||||
conv_state,
|
||||
weight,
|
||||
bias,
|
||||
activation in ["silu", "swish"],
|
||||
None,
|
||||
None,
|
||||
))
|
||||
opcheck(torch.ops._C.causal_conv1d_update,
|
||||
(x, conv_state, weight, bias, activation
|
||||
in ["silu", "swish"], None, None, PAD_SLOT_ID))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype",
|
||||
@ -278,37 +259,48 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
|
||||
@pytest.mark.parametrize("seqlen", [1, 4, 5])
|
||||
@pytest.mark.parametrize("width", [2, 3, 4])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
|
||||
# tests correctness in case subset of the sequences are padded
|
||||
@pytest.mark.parametrize("with_padding", [True, False])
|
||||
def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
||||
seqlen, has_bias,
|
||||
silu_activation, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
|
||||
# set )seed
|
||||
# set seed
|
||||
seed_everything(0)
|
||||
batch = 64
|
||||
|
||||
x = torch.randn(batch, dim, 1, device=device, dtype=itype)
|
||||
batch_size = 3
|
||||
padding = 5 if with_padding else 0
|
||||
padded_batch_size = batch_size + padding
|
||||
total_entries = 10 * batch_size
|
||||
|
||||
total_entries = 10 * batch
|
||||
x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype)
|
||||
x_ref = x.clone()
|
||||
|
||||
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||
dtype=torch.int32, device=device)
|
||||
unused_states_bool = torch.ones(total_entries,
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
unused_states_bool[conv_state_indices] = False
|
||||
padded_state_indices = torch.concat([
|
||||
conv_state_indices,
|
||||
torch.as_tensor(
|
||||
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
|
||||
],
|
||||
dim=0)
|
||||
conv_state = torch.randn(total_entries,
|
||||
dim,
|
||||
width - 1,
|
||||
device=device,
|
||||
dtype=itype)
|
||||
conv_state_indices = torch.randperm(total_entries)[:batch].to(
|
||||
dtype=torch.int32, device=device)
|
||||
conv_state_for_padding_test = conv_state.clone()
|
||||
|
||||
weight = torch.randn(dim,
|
||||
width,
|
||||
device=device,
|
||||
dtype=itype,
|
||||
requires_grad=True)
|
||||
if has_bias:
|
||||
bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
|
||||
else:
|
||||
bias = None
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
|
||||
activation = None if not silu_activation else "silu"
|
||||
out = causal_conv1d_update(x,
|
||||
@ -316,45 +308,50 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
conv_state_indices=conv_state_indices)
|
||||
out_ref = causal_conv1d_update_ref(x,
|
||||
conv_state_indices=padded_state_indices,
|
||||
pad_slot_id=PAD_SLOT_ID)
|
||||
out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
|
||||
conv_state_ref,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation)
|
||||
|
||||
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
|
||||
assert torch.equal(conv_state[unused_states_bool],
|
||||
conv_state_for_padding_test[unused_states_bool])
|
||||
|
||||
opcheck(torch.ops._C.causal_conv1d_update, (
|
||||
x,
|
||||
conv_state,
|
||||
weight,
|
||||
bias,
|
||||
activation in ["silu", "swish"],
|
||||
None,
|
||||
conv_state_indices,
|
||||
))
|
||||
opcheck(torch.ops._C.causal_conv1d_update,
|
||||
(x, conv_state, weight, bias, activation
|
||||
in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [True])
|
||||
@pytest.mark.parametrize("has_bias", [True])
|
||||
@pytest.mark.parametrize("width", [4])
|
||||
@pytest.mark.parametrize('seqlen',
|
||||
[8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
|
||||
@pytest.mark.parametrize(
|
||||
'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096])
|
||||
@pytest.mark.parametrize('dim', [64, 4096])
|
||||
def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
|
||||
itype):
|
||||
# tests correctness in case subset of the sequences are padded
|
||||
@pytest.mark.parametrize('with_padding', [True, False])
|
||||
def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
||||
silu_activation, itype):
|
||||
device = "cuda"
|
||||
torch.cuda.empty_cache()
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
seed_everything(0)
|
||||
batch = 1
|
||||
seqlens = []
|
||||
nsplits = 3
|
||||
batch_size = 4
|
||||
if seqlen < 10:
|
||||
batch_size = 1
|
||||
padding = 3 if with_padding else 0
|
||||
padded_batch_size = batch_size + padding
|
||||
nsplits = padded_batch_size - 1
|
||||
|
||||
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
|
||||
seqlens.append(
|
||||
torch.diff(
|
||||
@ -364,10 +361,11 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
|
||||
assert sum(seqlens[-1]) == seqlen
|
||||
assert all(s > 0 for s in seqlens[-1])
|
||||
|
||||
total_entries = batch_size * 10
|
||||
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
|
||||
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
|
||||
dim=0)
|
||||
x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device,
|
||||
x = torch.randn(1, 4096 + dim + 64, seqlen, device=device,
|
||||
dtype=itype)[:, 4096:4096 + dim, :]
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
@ -375,7 +373,7 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
|
||||
weight_ref = weight.clone()
|
||||
bias_ref = bias.clone() if bias is not None else None
|
||||
activation = None if not silu_activation else "silu"
|
||||
final_states = torch.randn(nsplits + 1,
|
||||
final_states = torch.randn(total_entries,
|
||||
dim,
|
||||
width - 1,
|
||||
device=x.device,
|
||||
@ -385,18 +383,27 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
|
||||
2, (cumsum.shape[0] - 1, ),
|
||||
dtype=torch.bool,
|
||||
device=x.device)
|
||||
cache_indices = torch.randperm(cumsum.shape[0] - 1,
|
||||
state_indices = torch.randperm(total_entries,
|
||||
dtype=torch.int32,
|
||||
device=x.device)
|
||||
device=x.device)[:batch_size]
|
||||
padded_state_indices = torch.concat([
|
||||
state_indices,
|
||||
torch.as_tensor(
|
||||
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||
],
|
||||
dim=-1)
|
||||
|
||||
out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
|
||||
cache_indices, has_initial_states, final_states,
|
||||
activation)
|
||||
padded_state_indices, has_initial_states,
|
||||
final_states, activation, PAD_SLOT_ID)
|
||||
out_ref = []
|
||||
out_ref_b = []
|
||||
|
||||
splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
|
||||
for i in range(len(seqlens[0])):
|
||||
x_s = [v[i].unsqueeze(0) for v in splits][0]
|
||||
if padded_state_indices[i] == PAD_SLOT_ID:
|
||||
continue
|
||||
out_ref_b.append(
|
||||
causal_conv1d_ref(
|
||||
x_s,
|
||||
@ -404,21 +411,17 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
|
||||
bias_ref,
|
||||
activation=activation,
|
||||
return_final_states=True,
|
||||
final_states_out=final_states_ref[cache_indices[i]].unsqueeze(
|
||||
0),
|
||||
initial_states=final_states_ref[cache_indices[i]].unsqueeze(0)
|
||||
if has_initial_states[i] else None))
|
||||
final_states_out=final_states_ref[
|
||||
padded_state_indices[i]].unsqueeze(0),
|
||||
initial_states=final_states_ref[padded_state_indices[i]].
|
||||
unsqueeze(0) if has_initial_states[i] else None))
|
||||
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
|
||||
out_ref = torch.cat(out_ref, dim=0)
|
||||
out_ref_tensor = torch.cat(out_ref, dim=0)
|
||||
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
print("Output state max diff"
|
||||
f":{(final_states - final_states_ref).abs().max()}")
|
||||
print("Output state mean diff"
|
||||
f":{(final_states - final_states_ref).abs().mean()}")
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
|
||||
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
|
||||
assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)
|
||||
|
||||
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
|
||||
cache_indices, has_initial_states, final_states,
|
||||
activation)
|
||||
padded_state_indices, has_initial_states,
|
||||
final_states, activation)
|
||||
|
@ -5,6 +5,7 @@ from einops import rearrange, repeat
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops # noqa: F401
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
selective_scan_fn, selective_state_update)
|
||||
from vllm.utils import seed_everything
|
||||
@ -174,7 +175,8 @@ def selective_scan_opcheck_fn(u,
|
||||
cu_seq_len=None,
|
||||
cache_indices=None,
|
||||
has_initial_state=None,
|
||||
ssm_states=None):
|
||||
ssm_states=None,
|
||||
pad_slot_id=PAD_SLOT_ID):
|
||||
"""if return_last_state is True, returns (out, last_state)
|
||||
last_state has shape (batch, dim, dstate).
|
||||
"""
|
||||
@ -203,7 +205,7 @@ def selective_scan_opcheck_fn(u,
|
||||
# a bogus error.
|
||||
opcheck(torch.ops._C.selective_scan_fwd,
|
||||
(u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seq_len,
|
||||
cache_indices, has_initial_state, ssm_states),
|
||||
cache_indices, has_initial_state, ssm_states, pad_slot_id),
|
||||
test_utils=["test_schema", "test_faketensor"])
|
||||
|
||||
|
||||
@ -404,9 +406,12 @@ def test_selective_state_update(dim, dstate, has_z, itype):
|
||||
@pytest.mark.parametrize("varBC_groups", [1, 2])
|
||||
@pytest.mark.parametrize("is_variable_C", [True])
|
||||
@pytest.mark.parametrize("is_variable_B", [True])
|
||||
def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
|
||||
has_D, has_z, has_delta_bias, delta_softplus,
|
||||
return_last_state, seqlen, itype, wtype):
|
||||
# tests correctness in case subset of the sequences are padded
|
||||
@pytest.mark.parametrize("with_padding", [False, True])
|
||||
def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C,
|
||||
varBC_groups, has_D, has_z, has_delta_bias,
|
||||
delta_softplus, return_last_state, seqlen,
|
||||
itype, wtype):
|
||||
if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
|
||||
pytest.skip() # This config is not applicable
|
||||
device = 'cuda'
|
||||
@ -420,18 +425,27 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
seqlens = []
|
||||
nsplits = 3
|
||||
batch_size = 4
|
||||
if seqlen < 10:
|
||||
nsplits = 0
|
||||
batch_size = 1
|
||||
padding = 3 if with_padding else 0
|
||||
padded_batch_size = batch_size + padding
|
||||
|
||||
if with_padding and seqlen < padded_batch_size:
|
||||
pytest.skip()
|
||||
|
||||
nsplits = padded_batch_size - 1
|
||||
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
|
||||
seqlens.append(
|
||||
torch.diff(
|
||||
torch.cat(
|
||||
[torch.tensor([-1]), eos_pos,
|
||||
torch.tensor([seqlen - 1])])).tolist())
|
||||
|
||||
assert sum(seqlens[-1]) == seqlen
|
||||
assert all(s > 0 for s in seqlens[-1])
|
||||
|
||||
total_entries = batch_size * 10
|
||||
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
|
||||
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
|
||||
dim=0).cuda()
|
||||
@ -462,22 +476,33 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
|
||||
delta_ref = delta.clone()
|
||||
out = None
|
||||
out_ref = None
|
||||
prev_state_shape = (cumsum.shape[0] - 1, u.shape[0], int(A.shape[1]))
|
||||
|
||||
prev_state_shape = (total_entries, u.shape[0], int(A.shape[1]))
|
||||
prev_state = torch.randn(prev_state_shape,
|
||||
device=u.device,
|
||||
dtype=itype,
|
||||
requires_grad=False)
|
||||
prev_state_ref = prev_state.clone()
|
||||
cache_indices = torch.randperm(cumsum.shape[0] - 1,
|
||||
state_indices = torch.randperm(total_entries,
|
||||
dtype=torch.int32,
|
||||
device=u.device)
|
||||
device=u.device)[:batch_size]
|
||||
unused_states_bool = torch.ones(total_entries,
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
unused_states_bool[state_indices] = False
|
||||
padded_state_indices = torch.concat([
|
||||
state_indices,
|
||||
torch.as_tensor(
|
||||
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||
],
|
||||
dim=-1)
|
||||
|
||||
has_initial_state = torch.randint(0,
|
||||
2, (cumsum.shape[0] - 1, ),
|
||||
dtype=torch.bool,
|
||||
device=u.device)
|
||||
out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias,
|
||||
delta_softplus, cumsum, cache_indices,
|
||||
delta_softplus, cumsum, padded_state_indices,
|
||||
has_initial_state)
|
||||
outs_ref = []
|
||||
splits = [
|
||||
@ -486,6 +511,8 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
|
||||
]
|
||||
for i in range(len(seqlens[0])):
|
||||
u_s, delta_s, B_s, C_s, z_s = [v[i].unsqueeze(0) for v in splits]
|
||||
if padded_state_indices[i] == PAD_SLOT_ID:
|
||||
continue
|
||||
out_ref_s, _ = selective_scan_ref(
|
||||
u_s,
|
||||
delta_s,
|
||||
@ -497,21 +524,22 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
|
||||
delta_bias=delta_bias,
|
||||
delta_softplus=delta_softplus,
|
||||
return_last_state=return_last_state,
|
||||
prev_state=prev_state_ref[cache_indices[i]].unsqueeze(0)
|
||||
prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0)
|
||||
if has_initial_state[i] else None,
|
||||
final_state_out=prev_state_ref[cache_indices[i]].unsqueeze(0))
|
||||
final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze(
|
||||
0))
|
||||
outs_ref.append(out_ref_s)
|
||||
out_ref = torch.cat(outs_ref, dim=-1) if len(outs_ref) > 1 else outs_ref[0]
|
||||
out_ref = torch.cat(outs_ref, dim=-1)[0]
|
||||
|
||||
print("Output diff max", (out - out_ref[0]).max())
|
||||
print("Output diff mean", (out - out_ref[0]).mean())
|
||||
unpadded_out = out[:, :out_ref[0].shape[-1]]
|
||||
print("Output diff max", (unpadded_out - out_ref).max())
|
||||
print("Output diff mean", (unpadded_out - out_ref).mean())
|
||||
print("Output state diff max", (prev_state - prev_state_ref).max())
|
||||
print("Output state diff mean", (prev_state - prev_state_ref).mean())
|
||||
assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol)
|
||||
assert torch.allclose(out, out_ref[0], rtol=rtol, atol=atol)
|
||||
|
||||
assert torch.allclose(unpadded_out, out_ref, rtol=rtol, atol=atol)
|
||||
selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias,
|
||||
delta_softplus, cumsum, cache_indices,
|
||||
delta_softplus, cumsum, padded_state_indices,
|
||||
has_initial_state, prev_state)
|
||||
|
||||
|
||||
@ -520,7 +548,10 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
|
||||
@pytest.mark.parametrize("has_z", [True])
|
||||
@pytest.mark.parametrize("dstate", [16, 32, 64])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
|
||||
# tests correctness in case subset of the sequences are padded
|
||||
@pytest.mark.parametrize("with_padding", [True, False])
|
||||
def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
|
||||
has_z, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
|
||||
if itype == torch.bfloat16:
|
||||
@ -530,21 +561,32 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 3
|
||||
|
||||
padding = 5 if with_padding else 0
|
||||
padded_batch_size = batch_size + padding
|
||||
total_entries = 10 * batch_size
|
||||
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
|
||||
state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||
dtype=torch.int32, device=device)
|
||||
|
||||
x = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||
unused_states_bool = torch.ones(total_entries,
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
unused_states_bool[state_indices] = False
|
||||
padded_state_indices = torch.concat([
|
||||
state_indices,
|
||||
torch.as_tensor(
|
||||
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
|
||||
],
|
||||
dim=0)
|
||||
x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
|
||||
dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
|
||||
dt_bias = torch.rand(dim, device=device) - 4.0
|
||||
A = -torch.rand(dim, dstate, device=device) - 1.0
|
||||
B = torch.randn(batch_size, dstate, device=device)
|
||||
C = torch.randn(batch_size, dstate, device=device)
|
||||
B = torch.randn(padded_batch_size, dstate, device=device)
|
||||
C = torch.randn(padded_batch_size, dstate, device=device)
|
||||
D = torch.randn(dim, device=device)
|
||||
z = torch.randn_like(x) if has_z else None
|
||||
state_ref = state[state_indices, :].detach().clone()
|
||||
state_ref = state[state_indices, :].clone()
|
||||
state_before = state.clone()
|
||||
out = selective_state_update(state,
|
||||
x,
|
||||
dt,
|
||||
@ -555,15 +597,16 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices)
|
||||
state_batch_indices=padded_state_indices,
|
||||
pad_slot_id=PAD_SLOT_ID)
|
||||
out_ref = selective_state_update_ref(state_ref,
|
||||
x,
|
||||
dt,
|
||||
x[:batch_size],
|
||||
dt[:batch_size],
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
B[:batch_size],
|
||||
C[:batch_size],
|
||||
D=D,
|
||||
z=z,
|
||||
z=z[:batch_size],
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True)
|
||||
|
||||
@ -572,11 +615,21 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
|
||||
print("Output state diff max", (state[state_indices, :] - state_ref).max())
|
||||
print("Output state diff mean",
|
||||
(state[state_indices, :] - state_ref).mean())
|
||||
# test padded entries stay the same
|
||||
if with_padding:
|
||||
assert torch.equal(state_before[unused_states_bool],
|
||||
state[unused_states_bool])
|
||||
assert torch.equal(x[batch_size + 1:], x[batch_size + 1:])
|
||||
assert torch.equal(dt[batch_size + 1:], dt[batch_size + 1:])
|
||||
assert torch.equal(B[batch_size + 1:], B[batch_size + 1:])
|
||||
assert torch.equal(C[batch_size + 1:], C[batch_size + 1:])
|
||||
|
||||
# test "real" entries
|
||||
assert torch.allclose(state[state_indices, :],
|
||||
state_ref,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype",
|
||||
@ -645,7 +698,8 @@ def test_selective_state_update_with_heads_with_batch_indices(
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices)
|
||||
state_batch_indices=state_indices,
|
||||
pad_slot_id=PAD_SLOT_ID)
|
||||
out_ref = selective_state_update_ref(state_ref,
|
||||
x,
|
||||
dt,
|
||||
|
@@ -1,5 +1,6 @@
import pytest

from tests.utils import multi_gpu_test
from vllm.sampling_params import SamplingParams
from vllm.worker.model_runner import _get_graph_batch_size

@@ -270,6 +271,30 @@ def test_state_cleanup(
"could be related to finished_requests_ids")


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
def test_jamba_distributed_produces_identical_generation(
vllm_runner, model: str, dtype: str, max_tokens: int,
example_prompts) -> None:

with vllm_runner(model, dtype=dtype, tensor_parallel_size=2) as vllm_model:
vllm_outputs_tp_2 = vllm_model.generate_greedy(example_prompts,
max_tokens)

with vllm_runner(model, dtype=dtype, tensor_parallel_size=1) as vllm_model:
vllm_outputs_tp_1 = vllm_model.generate_greedy(example_prompts,
max_tokens)

check_outputs_equal(
outputs_0_lst=vllm_outputs_tp_1,
outputs_1_lst=vllm_outputs_tp_2,
name_0="vllm_tp_1",
name_1="vllm_tp_2",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_model_print(
@@ -464,16 +464,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
cu_seq_len: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
return torch.empty_like(x)
silu_activation: bool, pad_slot_id: int):
return None

@register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake(
x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor], silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
return torch.empty_like(x)
def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
weight: torch.Tensor,
bias_: Optional[torch.Tensor],
silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor],
pad_slot_id: int) -> None:
return None

@register_fake("_C::selective_scan_fwd")
def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
@@ -485,7 +487,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
cu_seq_len: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
ssm_states: Optional[torch.Tensor]) -> None:
ssm_states: Optional[torch.Tensor],
pad_slot_id: int) -> None:
return None


@@ -800,33 +803,37 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
query_start_loc: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
query_start_loc, cache_indices,
has_initial_state, silu_activation)
silu_activation: bool, pad_slot_id: int):
torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
query_start_loc, cache_indices,
has_initial_state, silu_activation,
pad_slot_id)


def causal_conv1d_update(
x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor], silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
silu_activation, cache_seqlens,
conv_state_indices)
def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
weight: torch.Tensor, bias_: Optional[torch.Tensor],
silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor],
pad_slot_id: int):
torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
silu_activation, cache_seqlens,
conv_state_indices, pad_slot_id)


def selective_scan_fwd(
u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor,
C: torch.Tensor, D_: Optional[torch.Tensor],
z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor],
delta_softplus: bool, query_start_loc: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor], ssm_states: torch.Tensor):
def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, C: torch.Tensor,
D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
delta_bias_: Optional[torch.Tensor],
delta_softplus: bool,
query_start_loc: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
ssm_states: torch.Tensor, pad_slot_id: int):
torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_,
delta_softplus, query_start_loc,
cache_indices, has_initial_state,
ssm_states)
ssm_states, pad_slot_id)


# moe
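Note the calling-convention change above: the wrappers no longer allocate and return an output tensor; the ops are registered as returning `()` and write their results into the buffers they receive. A trivial stand-in showing what callers must now do (toy code, not part of the diff):

import torch

def inplace_scale(x: torch.Tensor, s: float) -> None:
    """Toy analogue of the new kernel wrappers: mutate x, return nothing."""
    x.mul_(s)

x = torch.ones(2, 3)
result = inplace_scale(x, 2.0)   # returns None, like causal_conv1d_fwd now
assert result is None
out = x                          # callers read the output from the input buffer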
@@ -6,18 +6,18 @@ from typing import Optional
import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID


def causal_conv1d_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
query_start_loc: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
):
def causal_conv1d_fn(x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
query_start_loc: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID):
"""
x: (batch, dim, seqlen) or (dim, cu_seq_len) for varlen
sequences are concatenated from left to right for varlen
@@ -37,6 +37,13 @@ def causal_conv1d_fn(
conv_states: (..., dim, width - 1) itype
updated inplace if provided
activation: either None or "silu" or "swish"
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3

out: (batch, dim, seqlen)
"""
@@ -46,10 +53,10 @@ def causal_conv1d_fn(
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None

out = ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc,
cache_indices, has_initial_state, activation
in ["silu", "swish"])
return out
ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc,
cache_indices, has_initial_state, activation
in ["silu", "swish"], pad_slot_id)
return x


def causal_conv1d_update(x: torch.Tensor,
@@ -58,7 +65,8 @@ def causal_conv1d_update(x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = None,
cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None):
conv_state_indices: Optional[torch.Tensor] = None,
pad_slot_id: int = PAD_SLOT_ID):
"""
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
@@ -73,7 +81,12 @@ def causal_conv1d_update(x: torch.Tensor,
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.

pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
@@ -82,8 +95,8 @@ def causal_conv1d_update(x: torch.Tensor,
unsqueeze = x.dim() == 2
if unsqueeze:
x = x.unsqueeze(-1)
out = ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val,
cache_seqlens, conv_state_indices)
ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val,
cache_seqlens, conv_state_indices, pad_slot_id)
if unsqueeze:
out = out.squeeze(-1)
return out
x = x.squeeze(-1)
return x
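For context, a hypothetical decode-step call of the updated `causal_conv1d_update` with a padded batch (a sketch only: it assumes a CUDA build of vLLM with these ops compiled, and all tensor contents are made up):

import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_update)

dim, width, total_entries = 64, 4, 16
x = torch.randn(4, dim, device="cuda")                 # 3 real requests + 1 pad
conv_state = torch.randn(total_entries, dim, width - 1, device="cuda")
weight = torch.randn(dim, width, device="cuda")
# rows 5, 2 and 9 of conv_state belong to running sequences; the last batch
# entry is padding and must leave the cache untouched
state_indices = torch.tensor([5, 2, 9, PAD_SLOT_ID],
                             dtype=torch.int32, device="cuda")
out = causal_conv1d_update(x, conv_state, weight, bias=None,
                           activation="silu",
                           conv_state_indices=state_indices,
                           pad_slot_id=PAD_SLOT_ID)
# out[:3] are valid; out[3] corresponds to the padded slot and is ignored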
@@ -1,14 +1,13 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py

from typing import Tuple

import torch
import triton
import triton.language as tl
from packaging import version

from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID

TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")

@@ -50,6 +49,7 @@ def _selective_scan_update_kernel(
z_ptr,
out_ptr,
state_batch_indices_ptr,
pad_slot_id,
# Matrix dimensions
batch,
nheads,
@@ -143,10 +143,11 @@ def _selective_scan_update_kernel(
if HAS_Z:
z_ptrs = z_ptr + offs_m * stride_z_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= (state_batch_idx != pad_slot_id)
state = tl.load(state_ptrs, mask=mask, other=0.0)

state = tl.load(state_ptrs,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
other=0.0)
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if not TIE_HDIM:
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
@@ -177,9 +178,11 @@ def _selective_scan_update_kernel(

dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
state = state * dA + dB * x[:, None]
tl.store(state_ptrs,
state,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))

mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= (state_batch_idx != pad_slot_id)
tl.store(state_ptrs, state, mask=mask)
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
out += x * D
@@ -198,7 +201,8 @@ def selective_state_update(state,
z=None,
dt_bias=None,
dt_softplus=False,
state_batch_indices=None):
state_batch_indices=None,
pad_slot_id=PAD_SLOT_ID):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
@@ -210,6 +214,12 @@ def selective_state_update(state,
D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim)
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
Return:
out: (batch, dim) or (batch, nheads, dim)
"""
@@ -276,6 +286,7 @@ def selective_state_update(state,
z,
out,
state_batch_indices,
pad_slot_id,
batch,
nheads,
dim,
@@ -319,22 +330,25 @@ def selective_state_update(state,
return out


def selective_scan_fn(
u,
ssm_states,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
query_start_loc=None,
cache_indices=None,
has_initial_state=None) -> Tuple[torch.Tensor, torch.Tensor]:
def selective_scan_fn(u,
ssm_states,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
query_start_loc=None,
cache_indices=None,
has_initial_state=None,
pad_slot_id=PAD_SLOT_ID) -> torch.Tensor:
"""
u: (dim, total_length) for varlen or (batch, dim, seqlen)
applies changes in place.
ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
applies changes in place.
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
A: (dim, dstate)
B: (ngroups, dstate, total_length) for varlen or
@@ -357,12 +371,14 @@ def selective_scan_fn(
indicate if the ssm_state at the corresponding index should be
used as initial state. Not providing argument assumes
there's no initial state

pad_slot_id: int
if cache_indices is passed, lets the kernel identify padding entries
that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at indices 0 and 3
returns
output: (dim, total_length) for varlen or (batch, dim, seqlen)
supports inplace replacement
last_state has shape (batch, dim, dstate).
supports inplace replacement if ssm_state was provided
"""
if u.stride(-1) != 1:
u = u.contiguous()
@@ -387,7 +403,7 @@ def selective_scan_fn(

ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus,
query_start_loc, cache_indices, has_initial_state,
ssm_states)
ssm_states, pad_slot_id)

if z is None:
return delta  # output written inplace to delta
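Similarly, a sketch of a padded single-token (decode) step through the SSM state update, assuming a CUDA device and illustrative tensor contents only:

import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
    selective_state_update)

dim, dstate, total_entries = 64, 16, 12
state = torch.randn(total_entries, dim, dstate, device="cuda")
x = torch.randn(3, dim, device="cuda")           # 2 real requests + 1 pad slot
dt = torch.randn(3, dim, device="cuda")
A = -torch.rand(dim, dstate, device="cuda") - 1.0
B = torch.randn(3, dstate, device="cuda")
C = torch.randn(3, dstate, device="cuda")
state_indices = torch.tensor([7, 0, PAD_SLOT_ID],
                             dtype=torch.int32, device="cuda")
out = selective_state_update(state, x, dt, A, B, C,
                             dt_softplus=True,
                             state_batch_indices=state_indices,
                             pad_slot_id=PAD_SLOT_ID)
# rows 7 and 0 of `state` are updated in place; the padded entry is skipped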
@ -1,6 +1,5 @@
|
||||
# coding=utf-8
|
||||
"""Inference-only Jamba model."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@ -29,7 +28,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
composed_weight_loader, default_weight_loader, sharded_weight_loader)
|
||||
from vllm.model_executor.models.mamba_cache import MambaCacheManager
|
||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@ -41,13 +41,6 @@ from .interfaces import HasInnerState, SupportsLoRA
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MambaCacheParams:
|
||||
is_prompt: bool = False
|
||||
conv_state: torch.Tensor = torch.Tensor()
|
||||
ssm_state: torch.Tensor = torch.Tensor()
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
|
||||
class JambaMambaMixer(nn.Module):
|
||||
"""
|
||||
@ -60,10 +53,9 @@ class JambaMambaMixer(nn.Module):
|
||||
**selective** state spaces)
|
||||
"""
|
||||
|
||||
def __init__(self, config: JambaConfig, layer_idx):
|
||||
def __init__(self, config: JambaConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.hidden_size = config.hidden_size
|
||||
self.ssm_state_size = config.mamba_d_state
|
||||
self.conv_kernel_size = config.mamba_d_conv
|
||||
@ -129,8 +121,8 @@ class JambaMambaMixer(nn.Module):
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata, conv_state: torch.Tensor,
|
||||
ssm_state: torch.Tensor):
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_cache_params: MambaCacheParams):
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
||||
@ -153,17 +145,18 @@ class JambaMambaMixer(nn.Module):
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=conv_state,
|
||||
conv_states=mamba_cache_params.conv_state,
|
||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
||||
cache_indices=mamba_cache_params.state_indices_tensor,
|
||||
query_start_loc=attn_metadata.query_start_loc)
|
||||
else:
|
||||
hidden_states = causal_conv1d_update(
|
||||
hidden_states.transpose(0, 1),
|
||||
conv_state,
|
||||
mamba_cache_params.conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
)
|
||||
conv_state_indices=mamba_cache_params.state_indices_tensor)
|
||||
hidden_states = hidden_states.transpose(0, 1)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
@ -188,7 +181,7 @@ class JambaMambaMixer(nn.Module):
|
||||
and attn_metadata.context_lens_tensor is not None:
|
||||
scan_outputs = selective_scan_fn(
|
||||
hidden_states,
|
||||
ssm_state,
|
||||
mamba_cache_params.ssm_state,
|
||||
discrete_time_step,
|
||||
self.A,
|
||||
B.transpose(-2, -1),
|
||||
@ -197,11 +190,12 @@ class JambaMambaMixer(nn.Module):
|
||||
gate,
|
||||
time_proj_bias,
|
||||
delta_softplus=True,
|
||||
cache_indices=mamba_cache_params.state_indices_tensor,
|
||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
||||
query_start_loc=attn_metadata.query_start_loc)
|
||||
else:
|
||||
scan_outputs = selective_state_update(
|
||||
ssm_state,
|
||||
mamba_cache_params.ssm_state,
|
||||
hidden_states.transpose(0, 1),
|
||||
discrete_time_step.transpose(0, 1),
|
||||
self.A,
|
||||
@ -211,7 +205,7 @@ class JambaMambaMixer(nn.Module):
|
||||
gate.transpose(0, 1),
|
||||
time_proj_bias,
|
||||
dt_softplus=True,
|
||||
)
|
||||
state_batch_indices=mamba_cache_params.state_indices_tensor)
|
||||
scan_outputs = scan_outputs.transpose(0, 1)
|
||||
|
||||
# 4. Final linear projection
|
||||
@ -292,7 +286,7 @@ class JambaMambaDecoderLayer(nn.Module):
super().__init__()
self.layer_idx = layer_idx
self.config = config
self.mamba = JambaMambaMixer(config, layer_idx)
self.mamba = JambaMambaMixer(config)

num_experts = config.layers_num_experts[layer_idx]
ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
@ -307,8 +301,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
mamba_cache_params: MambaCacheParams,
**kwargs,
):
if residual is None:
@ -318,8 +311,8 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm(
hidden_states, residual)

hidden_states = self.mamba(hidden_states, attn_metadata, conv_state,
ssm_state)
hidden_states = self.mamba(hidden_states, attn_metadata,
mamba_cache_params)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual)
@ -476,17 +469,14 @@ class JambaModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
mamba_cache_params: MambaCacheParams,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None

for i in range(len(self.layers)):
layer = self.layers[i]
kv_cache = None
current_ssm_state = None
current_conv_state = None
layer_mamba_cache_params = None
if isinstance(layer, JambaAttentionDecoderLayer):
kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
self.config.attn_layer_period]
@ -494,8 +484,8 @@ class JambaModel(nn.Module):
current_state_layer = i - (1 +
(i - self.config.attn_layer_offset)
// self.config.attn_layer_period)
current_ssm_state = ssm_state[current_state_layer]
current_conv_state = conv_state[current_state_layer]
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
current_state_layer)

hidden_states, residual = layer(
positions=positions,
@ -503,9 +493,7 @@ class JambaModel(nn.Module):
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=residual,
conv_state=current_conv_state,
ssm_state=current_ssm_state,
)
mamba_cache_params=layer_mamba_cache_params)
hidden_states, _ = self.final_layernorm(hidden_states, residual)
return hidden_states
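In JambaModel.forward only the plumbing changes: the mamba state layer index is computed exactly as before, but it now selects a MambaCacheParams view instead of raw conv/ssm tensors. A small worked example of that index arithmetic, using hypothetical config values:

# Hypothetical Jamba config: first attention layer at index 4, attention every 8 layers.
attn_layer_offset, attn_layer_period = 4, 8
i = 9  # a mamba layer that comes after exactly one attention layer (layer 4)
current_state_layer = i - (1 + (i - attn_layer_offset) // attn_layer_period)
assert current_state_layer == 8  # layer 9 is the ninth mamba layer, 0-indexed 8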
@ -588,13 +576,16 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
*self._get_mamba_cache_shape())

mamba_cache_tensors = self.mamba_cache.current_run_tensors(
input_ids, attn_metadata, **kwargs)

(
mamba_cache_tensors,
state_indices_tensor,
) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
**kwargs)
mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
mamba_cache_tensors[1],
state_indices_tensor)
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, mamba_cache_tensors[0],
mamba_cache_tensors[1])
attn_metadata, mamba_cache_params)
return hidden_states

def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
@ -27,7 +27,8 @@ from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.interfaces import (HasInnerState,
IsAttentionFree)
from vllm.model_executor.models.mamba_cache import MambaCacheManager
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
@ -110,8 +111,8 @@ class MambaMixer(nn.Module):
self.activation = config.hidden_act

def forward(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata, conv_state: torch.Tensor,
ssm_state: torch.Tensor):
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams):

# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@ -134,17 +135,18 @@ class MambaMixer(nn.Module):
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
conv_state,
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
)
conv_state_indices=mamba_cache_params.state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1)

# 3. State Space Model sequence transformation
@ -168,7 +170,7 @@ class MambaMixer(nn.Module):
and attn_metadata.context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
ssm_state,
mamba_cache_params.ssm_state,
discrete_time_step,
self.A,
B.transpose(-2, -1),
@ -177,11 +179,12 @@ class MambaMixer(nn.Module):
gate,
time_proj_bias,
delta_softplus=True,
cache_indices=mamba_cache_params.state_indices_tensor,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
scan_outputs = selective_state_update(
ssm_state,
mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
@ -191,7 +194,7 @@ class MambaMixer(nn.Module):
gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
)
state_batch_indices=mamba_cache_params.state_indices_tensor)
scan_outputs = scan_outputs.transpose(0, 1)

# 4. Final linear projection
@ -221,8 +224,7 @@ class MambaDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
mamba_cache_params: MambaCacheParams,
**kwargs,
):
if residual is None:
@ -231,8 +233,8 @@ class MambaDecoderLayer(nn.Module):
else:
hidden_states, residual = self.norm(hidden_states, residual)

hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
ssm_state)
hidden_states = self.mixer(hidden_states, attn_metadata,
mamba_cache_params)
return hidden_states, residual


@ -275,25 +277,20 @@ class MambaModel(nn.Module):
input_ids: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
mamba_cache_params: MambaCacheParams,
) -> torch.Tensor:

hidden_states = self.embeddings(input_ids)
residual = None

for i in range(len(self.layers)):
layer = self.layers[i]
current_ssm_state = ssm_state[i]
current_conv_state = conv_state[i]

hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
conv_state=current_conv_state,
ssm_state=current_ssm_state,
)
mamba_cache_params=mamba_cache_params.at_layer_idx(i))
hidden_states, _ = self.norm_f(hidden_states, residual)

return hidden_states
@ -347,12 +344,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
self.lm_head.weight.dtype, self.config.num_hidden_layers,
max_batch_size, *self._get_mamba_cache_shape())

mamba_cache_tensors = self.mamba_cache.current_run_tensors(
input_ids, attn_metadata, **kwargs)
(
mamba_cache_tensors,
state_indices_tensor,
) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
**kwargs)

mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
mamba_cache_tensors[1],
state_indices_tensor)

hidden_states = self.backbone(input_ids, positions, attn_metadata,
mamba_cache_tensors[0],
mamba_cache_tensors[1])
mamba_cache_params)

return hidden_states

@ -1,8 +1,22 @@
from typing import Dict, List, Optional
from dataclasses import dataclass
from typing import Dict, List

import torch

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.utils import PAD_SLOT_ID


@dataclass
class MambaCacheParams:
conv_state: torch.Tensor = torch.Tensor()
ssm_state: torch.Tensor = torch.Tensor()
state_indices_tensor: torch.Tensor = torch.Tensor()

def at_layer_idx(self, layer_idx):
return MambaCacheParams(self.conv_state[layer_idx],
self.ssm_state[layer_idx],
self.state_indices_tensor)
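MambaCacheParams bundles the two cache buffers with the slot indices of the current batch; at_layer_idx returns a per-layer view while the index tensor stays shared across layers. A short usage sketch, assuming this revision of vLLM is importable and using illustrative shapes:

import torch

from vllm.model_executor.models.mamba_cache import MambaCacheParams

# Illustrative shapes: 2 mamba layers, up to 8 sequences, dim 4, conv width 4, d_state 16.
conv_state = torch.zeros(2, 8, 4, 3)
ssm_state = torch.zeros(2, 8, 4, 16)
state_indices = torch.tensor([5, 2], dtype=torch.int32)

params = MambaCacheParams(conv_state, ssm_state, state_indices)
layer0 = params.at_layer_idx(0)
assert layer0.conv_state.shape == (8, 4, 3)                         # per-layer slice
assert layer0.state_indices_tensor is params.state_indices_tensor   # indices are shared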
class MambaCacheManager:
@ -24,6 +38,7 @@ class MambaCacheManager:
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size))

def current_run_tensors(self, input_ids: torch.Tensor,
attn_metadata: AttentionMetadata, **kwargs):
@ -36,30 +51,43 @@ class MambaCacheManager:
finished_requests_ids = kwargs["finished_requests_ids"]

self._release_finished_requests(finished_requests_ids)
mamba_cache_tensors = self._prepare_current_run_mamba_cache(
state_indices = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, finished_requests_ids)

state_indices_tensor = torch.as_tensor(state_indices,
dtype=torch.int32,
device="cuda")
mamba_cache_tensors = self.mamba_cache

else:
# CUDA graph capturing runs
mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"]
(mamba_cache_tensors,
state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"]

return mamba_cache_tensors
return (mamba_cache_tensors, state_indices_tensor)
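current_run_tensors therefore has two paths: eager runs build an int32 slot-index tensor from the freshly prepared list and return it alongside the full cache buffers, while CUDA-graph capture runs hand back the pre-captured buffers. A condensed, self-contained mirror of that branching; the function name, shapes, and kwargs handling are stand-ins, not vLLM's API.

import torch

def pick_run_tensors(mamba_cache, prepared_indices, **kwargs):
    # Simplified mirror of MambaCacheManager.current_run_tensors' two branches.
    if "seqlen_agnostic_capture_inputs" not in kwargs:
        # Eager run: full cache buffers plus freshly built slot indices.
        idx = torch.as_tensor(prepared_indices, dtype=torch.int32)
        return mamba_cache, idx
    # CUDA graph capture run: reuse the buffers that were captured into the graph.
    return kwargs["seqlen_agnostic_capture_inputs"]

cache = (torch.zeros(1, 8, 4, 3), torch.zeros(1, 8, 4, 16))  # toy (conv_state, ssm_state)
tensors, idx = pick_run_tensors(cache, [5, 2])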
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant Mamba cache into the CUDA graph input buffer
that was provided during the capture runs
(JambaForCausalLM.mamba_gc_cache_buffer).
Copy the relevant state_indices into the CUDA graph input buffer
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
assert "seqlen_agnostic_capture_inputs" in input_buffers
_, input_state_indices_buffer = input_buffers[
"seqlen_agnostic_capture_inputs"]

self._release_finished_requests(finished_requests_ids)
self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
finished_requests_ids)
state_indices = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, finished_requests_ids)
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
state_indices)
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)

input_state_indices_buffer.copy_(
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
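Before a CUDA graph replay, the freshly computed slot indices are padded with PAD_SLOT_ID up to the captured batch size and copied into the graph's input buffer, so the unused lanes become no-ops inside the kernels. A minimal sketch of that padding step; the PAD_SLOT_ID value and the CPU buffer are assumptions for illustration.

import torch

PAD_SLOT_ID = -1                        # assumed value of vllm.attention.backends.utils.PAD_SLOT_ID
graph_batch_size = 8                    # batch size the CUDA graph was captured with
state_indices = [5, 2, 7]               # only three sequences are actually running

pad_len = graph_batch_size - len(state_indices)
state_indices.extend([PAD_SLOT_ID] * pad_len)

input_state_indices_buffer = torch.empty(graph_batch_size, dtype=torch.int32)
input_state_indices_buffer.copy_(torch.as_tensor(state_indices, dtype=torch.int32))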
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
@ -67,13 +95,10 @@ class MambaCacheManager:
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)

def _swap_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for cache_t in self.mamba_cache:
cache_t[:, [to_index,from_index]] = \
cache_t[:, [from_index,to_index]]
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")
return (self.mamba_cache, state_indices_tensor)
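Note that the capture inputs no longer slice the cache to the first batch_size slots; slot selection now happens inside the kernels via the index buffer, which starts out filled with PAD_SLOT_ID and is overwritten by copy_inputs_before_cuda_graphs above. A toy mirror of the returned pair, with assumed shapes and sentinel:

import torch

PAD_SLOT_ID = -1  # assumed sentinel
batch_size = 8
mamba_cache = (torch.zeros(1, 8, 4, 3), torch.zeros(1, 8, 4, 16))  # toy (conv_state, ssm_state)

# Full buffers plus an all-pad index buffer sized to the capture batch.
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size, dtype=torch.int32)
capture_inputs = (mamba_cache, state_indices_tensor)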
def _copy_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
@ -81,142 +106,53 @@ class MambaCacheManager:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)

def _move_out_if_already_occupied(self, index: int,
all_occupied_indices: List[int]):
if index in all_occupied_indices:
first_free_index = self._first_free_index_in_mamba_cache()
# In case occupied, move the occupied to a new empty block
self._move_cache_index_and_mappings(from_index=index,
to_index=first_free_index)

def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
seq_id: int,
destination_index: int):
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
finished_requests_ids) -> int:
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
all_occupied_indices = self._get_all_occupied_indices()
if cur_rid not in self.mamba_cache_indices_mapping:
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
if cur_rid in finished_requests_ids:
# set as pad, do not allocate destination index
return PAD_SLOT_ID
elif cur_rid not in self.mamba_cache_indices_mapping:
destination_index = self.free_cache_indices.pop()
self.mamba_cache_indices_mapping[cur_rid] = {
seq_id: destination_index
}
return destination_index
elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]):
# parallel sampling , where n > 1, assume prefill have
# already happened now we only need to copy the already
# already happened, so we copy the
# existing cache into the siblings seq_ids caches
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
index_exists = list(seq_ids2indices.values())[0]
index_exists = next(iter(seq_ids2indices.values()))
# case of decoding n>1, copy prefill cache to decoding indices
destination_index = self.free_cache_indices.pop()
self._copy_mamba_cache(from_index=index_exists,
to_index=destination_index)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = destination_index
return destination_index
else:
# already exists
cache_index_already_exists = self.mamba_cache_indices_mapping[
cur_rid][seq_id]
if cache_index_already_exists != destination_index:
# In case the seq id already exists but not in
# the right destination, swap it with what's occupying it
self._swap_pair_indices_and_mappings(
from_index=cache_index_already_exists,
to_index=destination_index)
return self.mamba_cache_indices_mapping[cur_rid][seq_id]
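_assign_seq_id_to_cache_index replaces the old move/swap bookkeeping with a simple policy: requests that finish in this step get PAD_SLOT_ID, new sequences pop a slot from free_cache_indices, and sibling sequences (parallel sampling, n > 1) copy the existing cache into a fresh slot. A toy allocator that follows the same policy, written as a self-contained sketch rather than vLLM's class; the sibling-copy step is only hinted at in a comment.

PAD_SLOT_ID = -1  # assumed sentinel for padded or immediately-finished entries

class ToySlotAllocator:
    """Simplified stand-in for the request/sequence -> cache-slot bookkeeping."""

    def __init__(self, max_batch_size: int):
        self.free = list(range(max_batch_size))
        self.mapping: dict = {}          # req_id -> {seq_id: slot}

    def assign(self, req_id: str, seq_id: int, finished: set) -> int:
        if req_id in finished:
            return PAD_SLOT_ID           # finishing request: never allocate a slot
        seqs = self.mapping.setdefault(req_id, {})
        if seq_id not in seqs:
            seqs[seq_id] = self.free.pop()
            # (the real manager also copies a sibling's cache here for n > 1 sampling)
        return seqs[seq_id]

alloc = ToySlotAllocator(max_batch_size=4)
assert alloc.assign("req-a", 0, finished=set()) == 3
assert alloc.assign("req-b", 0, finished={"req-b"}) == PAD_SLOT_ID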
def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
finished_requests_ids: List[str]):
running_indices = []
request_ids_to_seq_ids_flatten = [
(req_id, seq_id)
finished_requests_ids: List[str]) -> List[int]:
return [
self._assign_seq_id_to_cache_index(req_id, seq_id,
finished_requests_ids)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
batch_size = len(request_ids_to_seq_ids_flatten)
for dest_index, (request_id,
seq_id) in enumerate(request_ids_to_seq_ids_flatten):
if request_id in finished_requests_ids:
# Do not allocate cache index for requests that run
# and finish right after
continue
self._assign_seq_id_to_mamba_cache_in_specific_dest(
request_id, seq_id, dest_index)
running_indices.append(dest_index)

self._clean_up_first_bs_blocks(batch_size, running_indices)
conv_state = self.mamba_cache[0][:, :batch_size]
temporal_state = self.mamba_cache[1][:, :batch_size]

return (conv_state, temporal_state)
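_prepare_current_run_mamba_cache collapses to a single flattening pass: one slot (or pad) per (request, sequence) pair, in batch order. A worked example with hypothetical request ids, using a condensed version of the per-pair policy:

# Hypothetical batch for one scheduler step.
request_ids_to_seq_ids = {"req-a": [0, 1], "req-b": [0]}
finished_requests_ids = ["req-b"]

PAD_SLOT_ID = -1                 # assumed sentinel
free_cache_indices = [3, 2, 1, 0]

def assign(req_id, seq_id):
    # Condensed policy: pad finishing requests, otherwise take any free slot.
    return PAD_SLOT_ID if req_id in finished_requests_ids else free_cache_indices.pop()

state_indices = [
    assign(req_id, seq_id)
    for req_id, seq_ids in request_ids_to_seq_ids.items()
    for seq_id in seq_ids
]
assert state_indices == [0, 1, PAD_SLOT_ID]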
def _get_all_occupied_indices(self):
return [
cache_idx
for seq_ids2indices in self.mamba_cache_indices_mapping.values()
for cache_idx in seq_ids2indices.values()
]

def _clean_up_first_bs_blocks(self, batch_size: int,
indices_for_current_run: List[int]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
destination_indices = range(batch_size)
max_possible_batch_size = self.mamba_cache[0].shape[1]
for destination_index in destination_indices:
if destination_index in self._get_all_occupied_indices() and \
destination_index not in indices_for_current_run:
# move not running indices outside of the batch
all_other_indices = list(
range(batch_size, max_possible_batch_size))
first_avail_index = self._first_free_index_in_mamba_cache(
all_other_indices)
self._swap_indices(from_index=destination_index,
to_index=first_avail_index)

def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
self._copy_mamba_cache(from_index=from_index, to_index=to_index)
self._update_mapping_index(from_index=from_index, to_index=to_index)

def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
self._swap_mamba_cache(from_index=from_index, to_index=to_index)
self._swap_mapping_index(from_index=from_index, to_index=to_index)

def _swap_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
elif to_index == index:
seq_ids2index.update({seq_id: from_index})

def _update_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
return

def _release_finished_requests(self,
finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.mamba_cache_indices_mapping:
for seq_id in self.mamba_cache_indices_mapping[req_id]:
self.free_cache_indices.append(
self.mamba_cache_indices_mapping[req_id][seq_id])
self.mamba_cache_indices_mapping.pop(req_id)

def _first_free_index_in_mamba_cache(
self, indices_range: Optional[List[int]] = None) -> int:
assert self.mamba_cache is not None
if indices_range is None:
max_possible_batch_size = self.mamba_cache[0].shape[1]
indices_range = list(range(max_possible_batch_size))
all_occupied_indices = self._get_all_occupied_indices()
for i in indices_range:
if i not in all_occupied_indices:
return i
raise Exception("Couldn't find a free spot in the mamba cache! This"
"should never happen")