[Kernel] Change interface to Mamba causal_conv1d_update for continuous batching (#8012)
This commit is contained in:
parent
09deb4721f
commit
8110e44529
@ -198,7 +198,8 @@ causal_conv1d_update(const at::Tensor &x,
|
||||
const at::Tensor &conv_state,
|
||||
const at::Tensor &weight,
|
||||
const c10::optional<at::Tensor> &bias_,
|
||||
bool silu_activation) {
|
||||
bool silu_activation,
|
||||
const c10::optional<at::Tensor> &conv_state_indices_) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
@ -216,7 +217,6 @@ causal_conv1d_update(const at::Tensor &x,
|
||||
const int width = weight.size(-1);
|
||||
|
||||
CHECK_SHAPE(x, batch_size, dim);
|
||||
CHECK_SHAPE(conv_state, batch_size, dim, width);
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
||||
@ -241,6 +241,22 @@ causal_conv1d_update(const at::Tensor &x,
|
||||
params.conv_state_c_stride = conv_state.stride(1);
|
||||
params.conv_state_l_stride = conv_state.stride(2);
|
||||
|
||||
if (conv_state_indices_.has_value()) {
|
||||
auto conv_state_indices = conv_state_indices_.value();
|
||||
TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
|
||||
TORCH_CHECK(conv_state_indices.is_cuda());
|
||||
TORCH_CHECK(conv_state_indices.stride(0) == 1)
|
||||
CHECK_SHAPE(conv_state_indices, batch_size);
|
||||
|
||||
int conv_state_entries = conv_state.size(0);
|
||||
CHECK_SHAPE(conv_state, conv_state_entries, dim, width);
|
||||
|
||||
params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
|
||||
} else {
|
||||
CHECK_SHAPE(conv_state, batch_size, dim, width);
|
||||
params.conv_state_indices_ptr = nullptr;
|
||||
}
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
||||
@ -646,8 +662,16 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
|
||||
const int channel_id = blockIdx.y * kNThreads + tidx;
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
|
||||
|
||||
// If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
|
||||
// along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
|
||||
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
|
||||
? batch_id
|
||||
: params.conv_state_indices_ptr[batch_id];
|
||||
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
|
||||
+ conv_state_batch_coord * params.conv_state_batch_stride
|
||||
+ channel_id * params.conv_state_c_stride;
|
||||
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
|
@ -36,6 +36,10 @@ struct ConvParamsBase {
|
||||
|
||||
void *__restrict__ conv_state_ptr;
|
||||
|
||||
// For the continuous batching case. Makes it so that the mamba state for
|
||||
// the current batch doesn't need to be a contiguous tensor.
|
||||
int32_t *__restrict__ conv_state_indices_ptr;
|
||||
|
||||
void *__restrict__ seq_idx_ptr;
|
||||
|
||||
// No __restrict__ since initial_states could be the same as final_states.
|
||||
|
@ -222,11 +222,10 @@ std::vector<torch::Tensor> selective_scan_fwd(
|
||||
const c10::optional<torch::Tensor>& index_,
|
||||
const c10::optional<torch::Tensor>& x);
|
||||
|
||||
at::Tensor causal_conv1d_update(const at::Tensor& x,
|
||||
const at::Tensor& conv_state,
|
||||
const at::Tensor& weight,
|
||||
const c10::optional<at::Tensor>& bias_,
|
||||
bool silu_activation);
|
||||
at::Tensor causal_conv1d_update(
|
||||
const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight,
|
||||
const c10::optional<at::Tensor>& bias, bool silu_activation,
|
||||
const c10::optional<at::Tensor>& conv_state_indices);
|
||||
|
||||
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
|
||||
const c10::optional<at::Tensor>& bias_,
|
||||
|
@ -279,8 +279,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"causal_conv1d_update(Tensor! x,"
|
||||
"Tensor! conv_state,"
|
||||
"Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"bool silu_activation) -> Tensor");
|
||||
"Tensor? bias,"
|
||||
"bool silu_activation,"
|
||||
"Tensor? conv_state_indices) -> Tensor");
|
||||
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
|
||||
|
||||
ops.def(
|
||||
|
@ -203,3 +203,61 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
|
||||
|
||||
assert torch.equal(conv_state, conv_state_ref)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||
@pytest.mark.parametrize("has_bias", [False, True])
|
||||
@pytest.mark.parametrize("seqlen", [1, 4, 5])
|
||||
@pytest.mark.parametrize("width", [2, 3, 4])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
|
||||
silu_activation, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch = 64
|
||||
|
||||
x = torch.randn(batch, dim, device=device, dtype=itype)
|
||||
|
||||
total_entries = 10 * batch
|
||||
conv_state = torch.randn(total_entries,
|
||||
dim,
|
||||
width,
|
||||
device=device,
|
||||
dtype=itype)
|
||||
conv_state_indices = torch.randperm(total_entries)[:batch].to(
|
||||
dtype=torch.int32, device=device)
|
||||
|
||||
weight = torch.randn(dim,
|
||||
width,
|
||||
device=device,
|
||||
dtype=itype,
|
||||
requires_grad=True)
|
||||
if has_bias:
|
||||
bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
|
||||
else:
|
||||
bias = None
|
||||
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
|
||||
activation = None if not silu_activation else "silu"
|
||||
out = causal_conv1d_update(x,
|
||||
conv_state,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
conv_state_indices=conv_state_indices)
|
||||
out_ref = causal_conv1d_update_ref(x,
|
||||
conv_state_ref,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation)
|
||||
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
@ -768,11 +768,17 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
|
||||
silu_activation)
|
||||
|
||||
|
||||
def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
|
||||
weight: torch.Tensor, bias_: Optional[torch.Tensor],
|
||||
silu_activation: bool) -> torch.Tensor:
|
||||
def causal_conv1d_update(
|
||||
x: torch.Tensor,
|
||||
conv_state: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias_: Optional[torch.Tensor],
|
||||
silu_activation: bool,
|
||||
conv_state_indices: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
|
||||
silu_activation)
|
||||
silu_activation,
|
||||
conv_state_indices)
|
||||
|
||||
|
||||
def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
|
||||
|
||||
from typing import Optional
|
||||
|
||||
@ -70,12 +71,17 @@ def causal_conv1d_update(x: torch.Tensor,
|
||||
conv_state: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = None):
|
||||
activation: Optional[str] = None,
|
||||
conv_state_indices: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
x: (batch, dim)
|
||||
conv_state: (batch, dim, width)
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
conv_state_indices: (batch,), dtype int32
|
||||
If not None, the conv_state is a larger tensor along the batch dim,
|
||||
and we are selecting the batch coords specified by conv_state_indices.
|
||||
Useful for a continuous batching scenario.
|
||||
|
||||
out: (batch, dim)
|
||||
"""
|
||||
@ -83,4 +89,4 @@ def causal_conv1d_update(x: torch.Tensor,
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
activation_bool = activation in ["silu", "swish"]
|
||||
return ops.causal_conv1d_update(x, conv_state, weight, bias,
|
||||
activation_bool)
|
||||
activation_bool, conv_state_indices)
|
||||
|
Loading…
x
Reference in New Issue
Block a user