[Kernel/Model] Migrate mamba_ssm and causal_conv1d kernels to vLLM (#7651)
Parent: 8c56e57def
Commit: fdd9daafa3
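For context, the kernels added in this commit implement the causal depthwise 1-D convolution and the selective-scan recurrence used by Mamba-style layers. Below is a minimal CPU reference of the convolution that causal_conv1d_fwd computes (left-zero-padded, optional fused SiLU); it is an illustrative sketch written for this review, not code from the commit.

// Hypothetical CPU reference for the causal depthwise conv1d implemented below.
// out[b][c][l] = silu?( bias[c] + sum_w weight[c][w] * x[b][c][l - (width-1) + w] )
#include <cmath>
#include <vector>

std::vector<float> causal_conv1d_ref(const std::vector<float>& x,      // [batch, dim, seqlen]
                                     const std::vector<float>& weight, // [dim, width]
                                     const std::vector<float>& bias,   // [dim] (may be empty)
                                     int batch, int dim, int seqlen, int width,
                                     bool silu_activation) {
  std::vector<float> out(x.size(), 0.f);
  for (int b = 0; b < batch; ++b) {
    for (int c = 0; c < dim; ++c) {
      for (int l = 0; l < seqlen; ++l) {
        float acc = bias.empty() ? 0.f : bias[c];
        for (int w = 0; w < width; ++w) {
          int src = l - (width - 1) + w;   // causal: only current and past positions
          if (src >= 0) acc += weight[c * width + w] * x[(b * dim + c) * seqlen + src];
        }
        if (silu_activation) acc = acc / (1.f + std::exp(-acc));
        out[(b * dim + c) * seqlen + l] = acc;
      }
    }
  }
  return out;
}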
@ -203,6 +203,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
FetchContent_MakeAvailable(cutlass)
|
||||
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
|
||||
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
|
||||
"csrc/quantization/aqlm/gemm_kernels.cu"
|
||||
"csrc/quantization/awq/gemm_kernels.cu"
|
||||
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
|
||||
|
Dockerfile (23 lines changed)

@@ -42,9 +42,6 @@ COPY requirements-cuda.txt requirements-cuda.txt
 RUN --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install -r requirements-cuda.txt

-COPY requirements-mamba.txt requirements-mamba.txt
-RUN python3 -m pip install packaging
-RUN python3 -m pip install -r requirements-mamba.txt

 # cuda arch list used by torch
 # can be useful for both `dev` and `test`
@@ -127,22 +124,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install -r requirements-dev.txt

 #################### DEV IMAGE ####################
-#################### MAMBA Build IMAGE ####################
-FROM dev as mamba-builder
-# max jobs used for build
-ARG max_jobs=2
-ENV MAX_JOBS=${max_jobs}
-
-WORKDIR /usr/src/mamba
-
-COPY requirements-mamba.txt requirements-mamba.txt
-
-# Download the wheel or build it if a pre-compiled release doesn't exist
-RUN pip --verbose wheel -r requirements-mamba.txt \
-    --no-build-isolation --no-deps --no-cache-dir
-
-#################### MAMBA Build IMAGE ####################
-
 #################### vLLM installation IMAGE ####################
 # image with vLLM installed
 FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base
@@ -179,10 +160,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
     --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install dist/*.whl --verbose

-RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \
-    --mount=type=cache,target=/root/.cache/pip \
-    python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
-
 RUN --mount=type=cache,target=/root/.cache/pip \
     . /etc/environment && \
     python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
csrc/mamba/causal_conv1d/causal_conv1d.cu (new file, 700 lines)

@@ -0,0 +1,700 @@
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
|
||||
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "causal_conv1d.h"
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
|
||||
#include "static_switch.h"
|
||||
|
||||
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
||||
if (ITYPE == at::ScalarType::Half) { \
|
||||
using input_t = at::Half; \
|
||||
using weight_t = at::Half; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
||||
using input_t = at::BFloat16; \
|
||||
using weight_t = at::BFloat16; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::Float) { \
|
||||
using input_t = float; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
||||
}
|
||||
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
|
||||
template <typename input_t, typename weight_t>
|
||||
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);
|
||||
|
||||
void set_conv_params_fwd(ConvParamsBase &params,
|
||||
// sizes
|
||||
const size_t batch,
|
||||
const size_t dim,
|
||||
const size_t seqlen,
|
||||
const size_t width,
|
||||
// device pointers
|
||||
const at::Tensor x,
|
||||
const at::Tensor weight,
|
||||
const at::Tensor out,
|
||||
void* bias_ptr,
|
||||
bool silu_activation) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(&params, 0, sizeof(params));
|
||||
|
||||
params.batch = batch;
|
||||
params.dim = dim;
|
||||
params.seqlen = seqlen;
|
||||
params.width = width;
|
||||
|
||||
params.silu_activation = silu_activation;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.x_ptr = x.data_ptr();
|
||||
params.weight_ptr = weight.data_ptr();
|
||||
params.bias_ptr = bias_ptr;
|
||||
params.out_ptr = out.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.x_batch_stride = x.stride(0);
|
||||
params.x_c_stride = x.stride(1);
|
||||
params.x_l_stride = x.stride(-1);
|
||||
params.weight_c_stride = weight.stride(0);
|
||||
params.weight_width_stride = weight.stride(1);
|
||||
params.out_batch_stride = out.stride(0);
|
||||
params.out_c_stride = out.stride(1);
|
||||
params.out_l_stride = out.stride(-1);
|
||||
}
|
||||
|
||||
|
||||
at::Tensor
|
||||
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
const c10::optional<at::Tensor> &bias_,
|
||||
const c10::optional<at::Tensor> &seq_idx_,
|
||||
const c10::optional<at::Tensor> &initial_states_,
|
||||
const c10::optional<at::Tensor> &final_states_out_,
|
||||
bool silu_activation) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int seqlen = sizes[2];
|
||||
const int width = weight.size(-1);
|
||||
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
|
||||
const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
|
||||
|
||||
if (is_channel_last) {
|
||||
TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
|
||||
TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
|
||||
}
|
||||
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
if (seq_idx_.has_value()) {
|
||||
TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout");
|
||||
auto seq_idx = seq_idx_.value();
|
||||
TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
|
||||
TORCH_CHECK(seq_idx.is_cuda());
|
||||
TORCH_CHECK(seq_idx.is_contiguous());
|
||||
CHECK_SHAPE(seq_idx, batch_size, seqlen);
|
||||
}
|
||||
|
||||
at::Tensor out = torch::empty_like(x);
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
||||
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
||||
silu_activation);
|
||||
|
||||
if (seq_idx_.has_value()) {
|
||||
params.seq_idx_ptr = seq_idx_.value().data_ptr();
|
||||
} else {
|
||||
params.seq_idx_ptr = nullptr;
|
||||
}
|
||||
|
||||
if (initial_states_.has_value()) {
|
||||
TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
|
||||
auto initial_states = initial_states_.value();
|
||||
TORCH_CHECK(initial_states.scalar_type() == input_type);
|
||||
TORCH_CHECK(initial_states.is_cuda());
|
||||
CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
|
||||
TORCH_CHECK(initial_states.stride(1) == 1);
|
||||
params.initial_states_ptr = initial_states.data_ptr();
|
||||
params.initial_states_batch_stride = initial_states.stride(0);
|
||||
params.initial_states_c_stride = initial_states.stride(1);
|
||||
params.initial_states_l_stride = initial_states.stride(2);
|
||||
} else {
|
||||
params.initial_states_ptr = nullptr;
|
||||
}
|
||||
|
||||
if (final_states_out_.has_value()) {
|
||||
TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout");
|
||||
auto final_states = final_states_out_.value();
|
||||
TORCH_CHECK(final_states.scalar_type() == input_type);
|
||||
TORCH_CHECK(final_states.is_cuda());
|
||||
CHECK_SHAPE(final_states, batch_size, dim, width - 1);
|
||||
TORCH_CHECK(final_states.stride(1) == 1);
|
||||
params.final_states_ptr = final_states.data_ptr();
|
||||
params.final_states_batch_stride = final_states.stride(0);
|
||||
params.final_states_c_stride = final_states.stride(1);
|
||||
params.final_states_l_stride = final_states.stride(2);
|
||||
} else {
|
||||
params.final_states_ptr = nullptr;
|
||||
}
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
|
||||
if (!is_channel_last) {
|
||||
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
} else {
|
||||
causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
}
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
at::Tensor
|
||||
causal_conv1d_update(const at::Tensor &x,
|
||||
const at::Tensor &conv_state,
|
||||
const at::Tensor &weight,
|
||||
const c10::optional<at::Tensor> &bias_,
|
||||
bool silu_activation) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == input_type, "weight type must be the same as the input type; other variations are disabled due to binary size limitations");
|
||||
TORCH_CHECK(conv_state.scalar_type() == input_type);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(conv_state.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int width = weight.size(-1);
|
||||
|
||||
CHECK_SHAPE(x, batch_size, dim);
|
||||
CHECK_SHAPE(conv_state, batch_size, dim, width);
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
at::Tensor out = torch::empty_like(x);
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out,
|
||||
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
||||
silu_activation);
|
||||
params.conv_state_ptr = conv_state.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.conv_state_batch_stride = conv_state.stride(0);
|
||||
params.conv_state_c_stride = conv_state.stride(1);
|
||||
params.conv_state_l_stride = conv_state.stride(2);
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
|
||||
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_fwd_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
||||
static_assert(kWidth <= kNElts);
|
||||
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
||||
static constexpr int kSmemIOSize = kIsVecLoad
|
||||
? 0
|
||||
: custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
||||
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
|
||||
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNElts = Ktraits::kNElts;
|
||||
static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
||||
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
||||
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
||||
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
||||
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y;
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
|
||||
if (tidx == 0) {
|
||||
input_t zeros[kNElts] = {0};
|
||||
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
|
||||
}
|
||||
|
||||
float weight_vals[kWidth];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
|
||||
constexpr int kChunkSize = kNThreads * kNElts;
|
||||
const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
|
||||
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
||||
input_t x_vals_load[2 * kNElts] = {0};
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
__syncthreads();
|
||||
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
x += kChunkSize;
|
||||
__syncthreads();
|
||||
// Thread kNThreads - 1 doesn't write yet, so that thread 0 can read
|
||||
// the last elements of the previous chunk.
|
||||
if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
__syncthreads();
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
|
||||
__syncthreads();
|
||||
// Now thread kNThreads - 1 can write the last elements of the current chunk.
|
||||
if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
|
||||
float x_vals[2 * kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
||||
|
||||
float out_vals[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = bias_val;
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
||||
}
|
||||
}
|
||||
|
||||
if (params.silu_activation) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
input_t out_vals_store[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
out += kChunkSize;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
|
||||
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
||||
BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
|
||||
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize;
|
||||
dim3 grid(params.batch, params.dim);
|
||||
|
||||
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
|
||||
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
#ifndef USE_ROCM
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
#else
|
||||
// There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
||||
#endif
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_channellast_fwd_kernel_traits {
|
||||
// The cache line is 128 bytes, and we try to read 16 bytes per thread.
|
||||
// So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
|
||||
// That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
|
||||
// threads). Each load is 16 x 32|64 elements in the L x C dimensions.
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static_assert(kNThreads % 32 == 0);
|
||||
static constexpr int kNWarps = kNThreads / 32;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kChunkSizeL = kChunkSizeL_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
||||
static constexpr int kNEltsPerRow = 128 / kNBytes;
|
||||
static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
|
||||
static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
|
||||
static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
|
||||
static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
|
||||
static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
|
||||
static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
|
||||
static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
|
||||
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
// using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
// using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
// static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
||||
// sizeof(typename BlockStoreT::TempStorage)});
|
||||
// static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
|
||||
};
|
||||
|
||||
template<typename Ktraits, bool kHasSeqIdx>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNElts = Ktraits::kNElts;
|
||||
constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
|
||||
constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
|
||||
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
||||
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
// Shared memory.
|
||||
__shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
|
||||
|
||||
const int batch_id = blockIdx.x;
|
||||
const int chunk_l_id = blockIdx.y;
|
||||
const int chunk_c_id = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
const int l_idx = tid / kNThreadsPerC;
|
||||
const int c_idx = tid % kNThreadsPerC;
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
|
||||
+ chunk_c_id * kChunkSizeC * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
|
||||
+ batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
|
||||
input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
|
||||
: reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
// The last L-chunk will also have enough info to write to final states, since it also contains a few x values
|
||||
// from the previous L-chunk.
|
||||
input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr
|
||||
: reinterpret_cast<input_t *>(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
||||
input_t x_vals_load[kNElts] = {0};
|
||||
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
|
||||
}
|
||||
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
||||
}
|
||||
// Load the elements from the previous chunk that are needed for convolution.
|
||||
if (l_idx < kWidth - 1) {
|
||||
input_t x_vals_load[kNElts] = {0};
|
||||
if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
|
||||
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
|
||||
} else if (initial_states != nullptr
|
||||
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
|
||||
}
|
||||
reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (final_states != nullptr
|
||||
&& l_idx < kWidth - 1
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
// x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1)
|
||||
// So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx]
|
||||
*reinterpret_cast<vec_t *>(final_states) = reinterpret_cast<vec_t *>(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx];
|
||||
}
|
||||
|
||||
constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
|
||||
static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
|
||||
constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
|
||||
static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
|
||||
// kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
|
||||
static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
|
||||
static_assert((kLPerThread & (kLPerThread - 1)) == 0);
|
||||
static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
|
||||
static_assert(kNThreadsPerRow <= 32);
|
||||
|
||||
const int row_idx = tid / kNThreadsPerRow;
|
||||
const int col_idx = tid % kNThreadsPerRow;
|
||||
|
||||
float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
|
||||
float weight_vals[kWidth] = {0};
|
||||
if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
|
||||
}
|
||||
}
|
||||
float x_vals[kWidth - 1 + kLPerThread];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
||||
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
|
||||
}
|
||||
int seq_idx_thread[kWidth - 1 + kLPerThread];
|
||||
if constexpr (kHasSeqIdx) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
||||
seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
|
||||
}
|
||||
}
|
||||
|
||||
float out_vals[kLPerThread];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kLPerThread; ++i) {
|
||||
out_vals[i] = bias_val;
|
||||
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
if constexpr (!kHasSeqIdx) {
|
||||
out_vals[i] += weight_vals[w] * x_vals[i + w];
|
||||
} else {
|
||||
out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
|
||||
}
|
||||
}
|
||||
if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
||||
input_t out_vals_store[kNElts];
|
||||
reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
|
||||
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
*reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_channellast_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
|
||||
BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
|
||||
using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
|
||||
// constexpr int kSmemSize = Ktraits::kSmemSize;
|
||||
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
||||
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
||||
const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
|
||||
const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
|
||||
dim3 grid(params.batch, n_chunks_L, n_chunks_C);
|
||||
dim3 block(Ktraits::kNThreads);
|
||||
auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits, kHasSeqIdx>;
|
||||
// if (kSmemSize >= 48 * 1024) {
|
||||
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
// }
|
||||
// kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
|
||||
|
||||
template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
|
||||
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
|
||||
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
|
||||
///////
|
||||
|
||||
|
||||
|
||||
|
||||
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_update_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_update_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y * kNThreads + tidx;
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
|
||||
+ channel_id * params.conv_state_c_stride;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
float weight_vals[kWidth] = {0};
|
||||
if (channel_id < params.dim) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
}
|
||||
|
||||
float x_vals[kWidth] = {0};
|
||||
if (channel_id < params.dim) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); }
|
||||
x_vals[kWidth - 1] = float(x[0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); }
|
||||
}
|
||||
|
||||
float out_val = bias_val;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
|
||||
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
|
||||
if (channel_id < params.dim) { out[0] = input_t(out_val); }
|
||||
}
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
|
||||
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
|
||||
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
|
||||
auto kernel = &causal_conv1d_update_kernel<Ktraits>;
|
||||
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
|
||||
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
|
||||
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
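The two entry points defined above are ordinary C++ functions taking ATen tensors, so a host-side sketch of how they could be exercised looks as follows; shapes follow the checks in causal_conv1d_fwd and causal_conv1d_update, and the snippet is an assumed usage example, not part of the commit.

// Hypothetical usage sketch for the ops above (assumes their declarations are visible and linked).
#include <torch/all.h>

at::Tensor run_causal_conv1d_example() {
  const int batch = 2, dim = 64, seqlen = 128, width = 4;
  auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
  at::Tensor x = at::randn({batch, dim, seqlen}, opts);   // (batch, dim, seqlen), contiguous
  at::Tensor w = at::randn({dim, width}, opts);           // one filter per channel
  at::Tensor b = at::randn({dim}, opts);
  // Prefill path: whole-sequence forward with the SiLU fused in.
  at::Tensor y = causal_conv1d_fwd(x, w, b,
                                   /*seq_idx_=*/c10::nullopt,
                                   /*initial_states_=*/c10::nullopt,
                                   /*final_states_out_=*/c10::nullopt,
                                   /*silu_activation=*/true);
  // Decode path: single-token update against a rolling conv_state of shape (batch, dim, width).
  at::Tensor conv_state = at::zeros({batch, dim, width}, opts);
  at::Tensor x_step = at::randn({batch, dim}, opts);
  at::Tensor y_step = causal_conv1d_update(x_step, conv_state, w, b, /*silu_activation=*/true);
  return y;
}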
csrc/mamba/causal_conv1d/causal_conv1d.h (new file, 144 lines)

@@ -0,0 +1,144 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ConvParamsBase {
|
||||
using index_t = uint32_t;
|
||||
|
||||
int batch, dim, seqlen, width;
|
||||
bool silu_activation;
|
||||
|
||||
index_t x_batch_stride;
|
||||
index_t x_c_stride;
|
||||
index_t x_l_stride;
|
||||
index_t weight_c_stride;
|
||||
index_t weight_width_stride;
|
||||
index_t out_batch_stride;
|
||||
index_t out_c_stride;
|
||||
index_t out_l_stride;
|
||||
|
||||
index_t conv_state_batch_stride;
|
||||
index_t conv_state_c_stride;
|
||||
index_t conv_state_l_stride;
|
||||
|
||||
// Common data pointers.
|
||||
void *__restrict__ x_ptr;
|
||||
void *__restrict__ weight_ptr;
|
||||
void *__restrict__ bias_ptr;
|
||||
void *__restrict__ out_ptr;
|
||||
|
||||
void *__restrict__ conv_state_ptr;
|
||||
|
||||
void *__restrict__ seq_idx_ptr;
|
||||
|
||||
// No __restrict__ since initial_states could be the same as final_states.
|
||||
void * initial_states_ptr;
|
||||
index_t initial_states_batch_stride;
|
||||
index_t initial_states_l_stride;
|
||||
index_t initial_states_c_stride;
|
||||
|
||||
void * final_states_ptr;
|
||||
index_t final_states_batch_stride;
|
||||
index_t final_states_l_stride;
|
||||
index_t final_states_c_stride;
|
||||
};
|
||||
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T shuffle_xor(T val, int offset) {
|
||||
return __shfl_xor_sync(uint32_t(-1), val, offset);
|
||||
}
|
||||
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return std::max(ilist);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return std::min(a, b);
|
||||
}
|
||||
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T shuffle_xor(T val, int offset) {
|
||||
return __shfl_xor(val, offset);
|
||||
}
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return *std::max_element(ilist.begin(), ilist.end());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int BYTES> struct BytesToType {};
|
||||
|
||||
template<> struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct SumOp {
|
||||
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
||||
};
|
||||
|
||||
template<int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Allreduce<2> {
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
};
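Two of the helpers declared in this header do the heavy lifting in the kernels: BytesToType maps a vector width in bytes to the unsigned type used for vectorized loads, and Allreduce<N> folds a value across N lanes with XOR shuffles. A small hypothetical sketch (not part of the commit) of how they behave:

// Hypothetical sketch exercising BytesToType and Allreduce from this header.
#include <type_traits>

static_assert(std::is_same_v<BytesToType<16>::Type, uint4>, "16-byte vector loads use uint4");
static_assert(std::is_same_v<BytesToType<2>::Type, uint16_t>, "2-byte loads use uint16_t");

__global__ void warp_sum_example(const float* in, float* out) {
  // Each warp reduces 32 consecutive inputs; Allreduce<32> halves the shuffle stride log2(32) times,
  // so every lane ends up holding the warp-wide sum.
  float v = in[blockIdx.x * 32 + (threadIdx.x & 31)];
  SumOp<float> op;
  float total = Allreduce<32>::run(v, op);
  if ((threadIdx.x & 31) == 0) out[blockIdx.x] = total;
}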
csrc/mamba/causal_conv1d/static_switch.h (new file, 28 lines)

@@ -0,0 +1,28 @@
|
||||
// Inspired by
|
||||
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
||||
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h
|
||||
|
||||
#pragma once
|
||||
|
||||
/// @param COND - a boolean expression to switch by
|
||||
/// @param CONST_NAME - a name given for the constexpr bool variable.
|
||||
/// @param ... - code to execute for true and false
|
||||
///
|
||||
/// Usage:
|
||||
/// ```
|
||||
/// BOOL_SWITCH(flag, BoolConst, [&] {
|
||||
/// some_function<BoolConst>(...);
|
||||
/// });
|
||||
/// ```
|
||||
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
static constexpr bool CONST_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
static constexpr bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
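Beyond the usage shown in the header comment, the point of BOOL_SWITCH is to lift a runtime flag into a constexpr so the body can use it as a template argument; both branches get compiled, each with its own specialization. A hypothetical sketch (not part of the commit):

// BOOL_SWITCH turns a runtime bool into a compile-time constant inside the generated lambda.
template <bool kUseBias>
void run_impl(int n) { /* kUseBias is usable here in `if constexpr` or as a template argument */ }

void run(int n, bool use_bias) {
  BOOL_SWITCH(use_bias, kUseBias, [&] {
    run_impl<kUseBias>(n);   // instantiates both run_impl<true> and run_impl<false>
  });
}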
csrc/mamba/mamba_ssm/selective_scan.h (new file, 276 lines)

@@ -0,0 +1,276 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
// clang-format off
|
||||
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#endif
|
||||
#include <cuda_fp16.h>
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SSMParamsBase {
|
||||
using index_t = uint32_t;
|
||||
|
||||
int batch, dim, seqlen, dstate, n_groups, n_chunks;
|
||||
int dim_ngroups_ratio;
|
||||
bool is_variable_B;
|
||||
bool is_variable_C;
|
||||
|
||||
bool delta_softplus;
|
||||
|
||||
index_t A_d_stride;
|
||||
index_t A_dstate_stride;
|
||||
index_t B_batch_stride;
|
||||
index_t B_d_stride;
|
||||
index_t B_dstate_stride;
|
||||
index_t B_group_stride;
|
||||
index_t C_batch_stride;
|
||||
index_t C_d_stride;
|
||||
index_t C_dstate_stride;
|
||||
index_t C_group_stride;
|
||||
index_t u_batch_stride;
|
||||
index_t u_d_stride;
|
||||
index_t delta_batch_stride;
|
||||
index_t delta_d_stride;
|
||||
index_t z_batch_stride;
|
||||
index_t z_d_stride;
|
||||
index_t out_batch_stride;
|
||||
index_t out_d_stride;
|
||||
index_t out_z_batch_stride;
|
||||
index_t out_z_d_stride;
|
||||
|
||||
// Common data pointers.
|
||||
void *__restrict__ A_ptr;
|
||||
void *__restrict__ B_ptr;
|
||||
void *__restrict__ C_ptr;
|
||||
void *__restrict__ D_ptr;
|
||||
void *__restrict__ u_ptr;
|
||||
void *__restrict__ delta_ptr;
|
||||
void *__restrict__ delta_bias_ptr;
|
||||
void *__restrict__ out_ptr;
|
||||
void *__restrict__ x_ptr;
|
||||
void *__restrict__ z_ptr;
|
||||
void *__restrict__ out_z_ptr;
|
||||
void *__restrict__ index_ptr;
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return std::max(ilist);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return std::min(a, b);
|
||||
}
|
||||
|
||||
#else
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return *std::max_element(ilist.begin(), ilist.end());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#define MAX_DSTATE 256
|
||||
|
||||
|
||||
inline __device__ float2 operator+(const float2 & a, const float2 & b){
|
||||
return {a.x + b.x, a.y + b.y};
|
||||
}
|
||||
|
||||
inline __device__ float3 operator+(const float3 &a, const float3 &b) {
|
||||
return {a.x + b.x, a.y + b.y, a.z + b.z};
|
||||
}
|
||||
|
||||
inline __device__ float4 operator+(const float4 & a, const float4 & b){
|
||||
return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int BYTES> struct BytesToType {};
|
||||
|
||||
template<> struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename scalar_t, int N>
|
||||
struct Converter{
|
||||
static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
|
||||
}
|
||||
};
|
||||
|
||||
template<int N>
|
||||
struct Converter<at::Half, N>{
|
||||
static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
|
||||
static_assert(N % 2 == 0);
|
||||
auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
|
||||
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
|
||||
}
|
||||
};
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
template<int N>
|
||||
struct Converter<at::BFloat16, N>{
|
||||
static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
|
||||
static_assert(N % 2 == 0);
|
||||
auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
|
||||
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
template<typename scalar_t> struct SSMScanOp;
|
||||
|
||||
template<>
|
||||
struct SSMScanOp<float> {
|
||||
__device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
|
||||
return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
|
||||
}
|
||||
};
|
||||
|
||||
// A stateful callback functor that maintains a running prefix to be applied
|
||||
// during consecutive scan operations.
|
||||
template <typename scalar_t> struct SSMScanPrefixCallbackOp {
|
||||
using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
|
||||
scan_t running_prefix;
|
||||
// Constructor
|
||||
__device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
|
||||
// Callback operator to be entered by the first warp of threads in the block.
|
||||
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
|
||||
__device__ scan_t operator()(scan_t block_aggregate) {
|
||||
scan_t old_prefix = running_prefix;
|
||||
running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
|
||||
return old_prefix;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void load_input(typename Ktraits::input_t *u,
|
||||
typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockLoadT::TempStorage &smem_load,
|
||||
int seqlen) {
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
|
||||
reinterpret_cast<vec_t*>(u),
|
||||
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
|
||||
#ifdef USE_ROCM
|
||||
, Ktraits::kNThreads * Ktraits::kNLoads
|
||||
#endif
|
||||
|
||||
);
|
||||
} else {
|
||||
typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void load_index(int *u,
|
||||
int (&u_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index,
|
||||
int seqlen) {
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_load_index_vec = reinterpret_cast<typename Ktraits::BlockLoadIndexVecT::TempStorage&>(smem_load_index);
|
||||
Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load(
|
||||
reinterpret_cast<uint4*>(u),
|
||||
reinterpret_cast<uint4(&)[Ktraits::kNLoadsIndex]>(u_vals)
|
||||
);
|
||||
} else {
|
||||
Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
|
||||
typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
|
||||
int seqlen) {
|
||||
constexpr int kNItems = Ktraits::kNItems;
|
||||
typename Ktraits::input_t B_vals_load[kNItems];
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
|
||||
reinterpret_cast<vec_t*>(Bvar),
|
||||
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
|
||||
);
|
||||
} else {
|
||||
typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
|
||||
}
|
||||
// #pragma unroll
|
||||
// for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
|
||||
Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
|
||||
}
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void store_output(typename Ktraits::input_t *out,
|
||||
const float (&out_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockStoreT::TempStorage &smem_store,
|
||||
int seqlen) {
|
||||
typename Ktraits::input_t write_vals[Ktraits::kNItems];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
|
||||
reinterpret_cast<vec_t*>(out),
|
||||
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
|
||||
);
|
||||
} else {
|
||||
typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
|
||||
}
|
||||
}
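The scan machinery above works because the per-timestep update h_t = a_t * h_{t-1} + b_t is an affine map, and SSMScanOp composes two such maps associatively, which is what lets cub::BlockScan parallelize the recurrence. A small hypothetical CPU check of that equivalence (not part of the commit):

// Folding pairs with the SSMScanOp<float> combination rule reproduces the sequential recurrence.
#include <vector>
#include <cuda_runtime.h>   // float2 / make_float2 on the host

float scan_recurrence_sequential(const std::vector<float>& a, const std::vector<float>& b, float h0) {
  float h = h0;
  for (size_t t = 0; t < a.size(); ++t) h = a[t] * h + b[t];
  return h;
}

float scan_recurrence_composed(const std::vector<float>& a, const std::vector<float>& b, float h0) {
  float2 run = make_float2(1.f, 0.f);   // identity map: h -> 1*h + 0
  for (size_t t = 0; t < a.size(); ++t) {
    // Same combination as SSMScanOp<float>::operator()(run, make_float2(a[t], b[t])).
    run = make_float2(a[t] * run.x, a[t] * run.y + b[t]);
  }
  return run.x * h0 + run.y;            // applying the composed map to h0 equals the sequential result
}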
csrc/mamba/mamba_ssm/selective_scan_fwd.cu (new file, 593 lines)

@@ -0,0 +1,593 @@
|
||||
// clang-format off
|
||||
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "selective_scan.h"
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
#include <cub/block/block_scan.cuh>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
namespace cub = hipcub;
|
||||
#endif
|
||||
|
||||
#include "selective_scan.h"
|
||||
#include "static_switch.h"
|
||||
|
||||
template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
|
||||
bool kIsVariableB_, bool kIsVariableC_,
|
||||
bool kHasZ_, bool kUseIndex_, typename input_t_, typename weight_t_>
|
||||
struct Selective_Scan_fwd_kernel_traits {
|
||||
static_assert(kNItems_ % 4 == 0);
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
|
||||
static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
|
||||
static constexpr int kNItems = kNItems_;
|
||||
static constexpr int kNRows = kNRows_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
|
||||
static_assert(kNItems % kNElts == 0);
|
||||
static constexpr int kNLoads = kNItems / kNElts;
|
||||
static constexpr bool kIsEvenLen = kIsEvenLen_;
|
||||
static constexpr bool kIsVariableB = kIsVariableB_;
|
||||
static constexpr bool kIsVariableC = kIsVariableC_;
|
||||
static constexpr bool kHasZ = kHasZ_;
|
||||
static constexpr bool kUseIndex = kUseIndex_;
|
||||
|
||||
static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
|
||||
static constexpr int kNLoadsIndex = kNItems / 4;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
using scan_t = float2;
|
||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
|
||||
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockLoadIndexT = cub::BlockLoad<int, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadIndexVecT = cub::BlockLoad<uint4, kNThreads, kNLoadsIndex,
|
||||
!(kIsEvenLen && kNLoadsIndex == 1) ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems , cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads ,
|
||||
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
|
||||
!kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
|
||||
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
|
||||
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
|
||||
using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
|
||||
static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
|
||||
sizeof(typename BlockLoadVecT::TempStorage),
|
||||
sizeof(typename BlockLoadIndexT::TempStorage),
|
||||
sizeof(typename BlockLoadIndexVecT::TempStorage),
|
||||
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
|
||||
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
|
||||
sizeof(typename BlockStoreT::TempStorage),
|
||||
sizeof(typename BlockStoreVecT::TempStorage)});
|
||||
static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
|
||||
void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
constexpr bool kIsVariableB = Ktraits::kIsVariableB;
|
||||
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
|
||||
constexpr bool kHasZ = Ktraits::kHasZ;
|
||||
constexpr bool kUseIndex = Ktraits::kUseIndex;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNItems = Ktraits::kNItems;
|
||||
constexpr int kNRows = Ktraits::kNRows;
|
||||
constexpr bool kDirectIO = Ktraits::kDirectIO;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
using scan_t = typename Ktraits::scan_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
// cast to lvalue reference of expected type
|
||||
// char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
|
||||
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
|
||||
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
|
||||
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
||||
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
|
||||
auto& smem_load_index = reinterpret_cast<typename Ktraits::BlockLoadIndexT::TempStorage&>(smem_);
|
||||
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
|
||||
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
||||
auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
|
||||
// weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
|
||||
// weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
|
||||
scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
|
||||
|
||||
const int batch_id = blockIdx.x;
|
||||
const int dim_id = blockIdx.y;
|
||||
const int group_id = dim_id / (params.dim_ngroups_ratio);
|
||||
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
|
||||
+ dim_id * kNRows * params.u_d_stride;
|
||||
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
|
||||
+ dim_id * kNRows * params.delta_d_stride;
|
||||
weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
|
||||
weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
|
||||
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
|
||||
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
|
||||
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
|
||||
scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
|
||||
int *index = !kUseIndex ? nullptr : reinterpret_cast<int *>(params.index_ptr) + batch_id * params.seqlen;
|
||||
|
||||
float D_val[kNRows] = {0};
|
||||
if (params.D_ptr != nullptr) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];
|
||||
}
|
||||
}
|
||||
float delta_bias[kNRows] = {0};
|
||||
if (params.delta_bias_ptr != nullptr) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
|
||||
// smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
|
||||
// smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
|
||||
// }
|
||||
|
||||
constexpr int kChunkSize = kNThreads * kNItems;
|
||||
for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
|
||||
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
|
||||
int index_vals_load[kNRows][kNItems];
|
||||
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
if constexpr (!kDirectIO) {
|
||||
if (r > 0) { __syncthreads(); }
|
||||
}
|
||||
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
|
||||
if constexpr (!kDirectIO) { __syncthreads(); }
|
||||
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
|
||||
if constexpr (kUseIndex) {
|
||||
load_index<Ktraits>(index + r * params.delta_d_stride, index_vals_load[r], smem_load_index, params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
}
|
||||
if constexpr (kUseIndex) {
|
||||
index += kChunkSize;
|
||||
}
|
||||
u += kChunkSize;
|
||||
delta += kChunkSize;
|
||||
|
||||
float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
float u_val = float(u_vals[r][i]);
|
||||
delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
|
||||
if (params.delta_softplus) {
|
||||
delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
|
||||
}
|
||||
delta_u_vals[r][i] = delta_vals[r][i] * u_val;
|
||||
out_vals[r][i] = D_val[r] * u_val;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
|
||||
weight_t A_val[kNRows];
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
|
||||
// Multiply A (real-valued here) by LOG2E so we can use exp2f instead of expf.
|
||||
constexpr float kLog2e = M_LOG2E;
|
||||
A_val[r] *= kLog2e;
|
||||
}
|
||||
// This variable holds B * C if both B and C are constant across seqlen. If only B varies
|
||||
// across seqlen, this holds C. If only C varies across seqlen, this holds B.
|
||||
// If both B and C vary, this is unused.
|
||||
weight_t BC_val[kNRows];
|
||||
weight_t B_vals[kNItems], C_vals[kNItems];
|
||||
if constexpr (kIsVariableB) {
|
||||
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
||||
smem_load_weight, (params.seqlen - chunk * kChunkSize) * (1));
|
||||
if constexpr (!kIsVariableC) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
|
||||
}
|
||||
}
|
||||
}
|
||||
if constexpr (kIsVariableC) {
|
||||
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
|
||||
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
||||
smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (1 ));
|
||||
if constexpr (!kIsVariableB) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
|
||||
}
|
||||
}
|
||||
}
|
||||
if constexpr (!kIsVariableB && !kIsVariableC) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
if (r > 0) { __syncthreads(); } // Scan could be using the same smem
|
||||
scan_t thread_data[kNItems];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
|
||||
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
|
||||
|
||||
// Reset A bar for cumulative sequences (Real)
|
||||
if constexpr (kUseIndex) {
|
||||
if (index_vals_load[r][i] == 0) {
|
||||
thread_data[i].x = 0.f;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
|
||||
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
|
||||
thread_data[i] = make_float2(1.f, 0.f);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Initialize running total
|
||||
scan_t running_prefix;
|
||||
// If we use WARP_SCAN then lane 0 of every warp (not just thread 0) needs to read
|
||||
running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : ( threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f));
|
||||
// running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
|
||||
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
||||
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
||||
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
|
||||
);
|
||||
// There's a syncthreads in the scan op, so we don't need to sync here.
|
||||
// Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
|
||||
if (threadIdx.x == 0) {
|
||||
smem_running_prefix[state_idx] = prefix_op.running_prefix;
|
||||
x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
const weight_t C_val = !kIsVariableC
|
||||
? BC_val[r]
|
||||
: (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
|
||||
out_vals[r][i] += thread_data[i].y * C_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
if constexpr (!kDirectIO) {
|
||||
if (r > 0) { __syncthreads(); }
|
||||
}
|
||||
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
|
||||
if constexpr (kHasZ) {
|
||||
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
|
||||
+ dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
|
||||
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
|
||||
+ dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
input_t z_vals[kNItems];
|
||||
__syncthreads();
|
||||
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
float z_val = z_vals[i];
|
||||
out_vals[r][i] *= z_val / (1 + expf(-z_val));
|
||||
}
|
||||
__syncthreads();
|
||||
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
}
|
||||
|
||||
Bvar += kChunkSize * 1;
|
||||
Cvar += kChunkSize * 1;
|
||||
}
|
||||
}
|
||||
|
||||
template<int kNThreads, int kNItems, typename input_t, typename weight_t>
|
||||
void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
|
||||
// Only kNRows == 1 is tested for now, which of course doesn't differ from before, when each block was
|
||||
// processing 1 row.
|
||||
constexpr int kNRows = 1;
|
||||
// kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
|
||||
constexpr bool kIsVariableB = true;
|
||||
constexpr bool kIsVariableC = true;
|
||||
constexpr bool kHasZ = true;
|
||||
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
||||
BOOL_SWITCH(params.index_ptr != nullptr, kUseIndex, [&] {
|
||||
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
|
||||
dim3 grid(params.batch, params.dim / kNRows);
|
||||
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) {
|
||||
|
||||
#ifndef USE_ROCM
|
||||
if (params.seqlen <= 128) {
|
||||
selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 256) {
|
||||
selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 512) {
|
||||
selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 1024) {
|
||||
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
|
||||
} else {
|
||||
selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
|
||||
}
|
||||
#else
|
||||
if (params.seqlen <= 256) {
|
||||
selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 512) {
|
||||
selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 1024) {
|
||||
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
|
||||
} else {
|
||||
selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
template void selective_scan_fwd_cuda<float, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
||||
if (ITYPE == at::ScalarType::Half) { \
|
||||
using input_t = at::Half; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
||||
using input_t = at::BFloat16; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::Float) { \
|
||||
using input_t = float; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
||||
}
|
||||
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
// sizes
|
||||
const size_t batch,
|
||||
const size_t dim,
|
||||
const size_t seqlen,
|
||||
const size_t dstate,
|
||||
const size_t n_groups,
|
||||
const size_t n_chunks,
|
||||
const bool is_variable_B,
|
||||
const bool is_variable_C,
|
||||
// device pointers
|
||||
const torch::Tensor u,
|
||||
const torch::Tensor delta,
|
||||
const torch::Tensor A,
|
||||
const torch::Tensor B,
|
||||
const torch::Tensor C,
|
||||
const torch::Tensor out,
|
||||
const torch::Tensor z,
|
||||
const torch::Tensor out_z,
|
||||
void* D_ptr,
|
||||
void* delta_bias_ptr,
|
||||
void* x_ptr,
|
||||
bool has_z,
|
||||
bool delta_softplus,
|
||||
void* index_ptr) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
|
||||
params.batch = batch;
|
||||
params.dim = dim;
|
||||
params.seqlen = seqlen;
|
||||
params.dstate = dstate;
|
||||
params.n_groups = n_groups;
|
||||
params.n_chunks = n_chunks;
|
||||
params.dim_ngroups_ratio = dim / n_groups;
|
||||
|
||||
params.delta_softplus = delta_softplus;
|
||||
|
||||
params.is_variable_B = is_variable_B;
|
||||
params.is_variable_C = is_variable_C;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.u_ptr = u.data_ptr();
|
||||
params.delta_ptr = delta.data_ptr();
|
||||
params.A_ptr = A.data_ptr();
|
||||
params.B_ptr = B.data_ptr();
|
||||
params.C_ptr = C.data_ptr();
|
||||
params.D_ptr = D_ptr;
|
||||
params.delta_bias_ptr = delta_bias_ptr;
|
||||
params.out_ptr = out.data_ptr();
|
||||
params.x_ptr = x_ptr;
|
||||
params.z_ptr = has_z ? z.data_ptr() : nullptr;
|
||||
params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
|
||||
|
||||
params.index_ptr = index_ptr;
|
||||
|
||||
// All strides are in elements, not bytes.
|
||||
params.A_d_stride = A.stride(0);
|
||||
params.A_dstate_stride = A.stride(1);
|
||||
if (!is_variable_B) {
|
||||
params.B_d_stride = B.stride(0);
|
||||
} else {
|
||||
params.B_batch_stride = B.stride(0);
|
||||
params.B_group_stride = B.stride(1);
|
||||
}
|
||||
params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
|
||||
if (!is_variable_C) {
|
||||
params.C_d_stride = C.stride(0);
|
||||
} else {
|
||||
params.C_batch_stride = C.stride(0);
|
||||
params.C_group_stride = C.stride(1);
|
||||
}
|
||||
params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
|
||||
params.u_batch_stride = u.stride(0);
|
||||
params.u_d_stride = u.stride(1);
|
||||
params.delta_batch_stride = delta.stride(0);
|
||||
params.delta_d_stride = delta.stride(1);
|
||||
if (has_z) {
|
||||
params.z_batch_stride = z.stride(0);
|
||||
params.z_d_stride = z.stride(1);
|
||||
params.out_z_batch_stride = out_z.stride(0);
|
||||
params.out_z_d_stride = out_z.stride(1);
|
||||
}
|
||||
params.out_batch_stride = out.stride(0);
|
||||
params.out_d_stride = out.stride(1);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor>
|
||||
selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C,
|
||||
const c10::optional<torch::Tensor> &D_,
|
||||
const c10::optional<torch::Tensor> &z_,
|
||||
const c10::optional<torch::Tensor> &delta_bias_,
|
||||
bool delta_softplus,
|
||||
const c10::optional<torch::Tensor> &index_,
|
||||
const c10::optional<torch::Tensor> &x) {
|
||||
auto input_type = u.scalar_type();
|
||||
auto weight_type = A.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float);
|
||||
|
||||
const bool is_variable_B = B.dim() >= 3;
|
||||
const bool is_variable_C = C.dim() >= 3;
|
||||
|
||||
TORCH_CHECK(delta.scalar_type() == input_type);
|
||||
TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
|
||||
TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
|
||||
|
||||
TORCH_CHECK(u.is_cuda());
|
||||
TORCH_CHECK(delta.is_cuda());
|
||||
TORCH_CHECK(A.is_cuda());
|
||||
TORCH_CHECK(B.is_cuda());
|
||||
TORCH_CHECK(C.is_cuda());
|
||||
|
||||
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
||||
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
||||
|
||||
const auto sizes = u.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int seqlen = sizes[2];
|
||||
const int dstate = A.size(1);
|
||||
const int n_groups = is_variable_B ? B.size(1) : 1;
|
||||
|
||||
TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
|
||||
|
||||
CHECK_SHAPE(u, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(A, dim, dstate);
|
||||
TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size")
|
||||
CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
|
||||
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
||||
|
||||
TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size")
|
||||
CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
|
||||
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
||||
|
||||
if (D_.has_value()) {
|
||||
auto D = D_.value();
|
||||
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
||||
TORCH_CHECK(D.is_cuda());
|
||||
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
||||
CHECK_SHAPE(D, dim);
|
||||
}
|
||||
|
||||
if (delta_bias_.has_value()) {
|
||||
auto delta_bias = delta_bias_.value();
|
||||
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
||||
TORCH_CHECK(delta_bias.is_cuda());
|
||||
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
||||
CHECK_SHAPE(delta_bias, dim);
|
||||
}
|
||||
if (index_.has_value()) {
|
||||
auto index = index_.value();
|
||||
TORCH_CHECK(index.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(index.is_cuda());
|
||||
CHECK_SHAPE(index, batch_size, seqlen);
|
||||
}
|
||||
|
||||
at::Tensor z, out_z;
|
||||
const bool has_z = z_.has_value();
|
||||
TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
|
||||
z = z_.value();
|
||||
TORCH_CHECK(z.scalar_type() == input_type);
|
||||
TORCH_CHECK(z.is_cuda());
|
||||
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
||||
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
||||
out_z = torch::empty_like(z);
|
||||
|
||||
const int n_chunks = (seqlen + 2048 - 1) / 2048;
|
||||
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
|
||||
// at::Tensor out = torch::empty_like(u);
|
||||
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
|
||||
at::Tensor out = torch::empty_like(delta);
|
||||
if (x.has_value()){
|
||||
auto _x = x.value();
|
||||
TORCH_CHECK(_x.scalar_type() == weight_type);
|
||||
TORCH_CHECK(_x.is_cuda());
|
||||
TORCH_CHECK(_x.stride(-1) == 1);
|
||||
CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2);
|
||||
}
|
||||
|
||||
SSMParamsBase params;
|
||||
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
|
||||
u, delta, A, B, C, out, z, out_z,
|
||||
D_.has_value() ? D_.value().data_ptr() : nullptr,
|
||||
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
|
||||
x.value().data_ptr(),
|
||||
has_z,
|
||||
delta_softplus,
|
||||
index_.has_value() ? index_.value().data_ptr() : nullptr);
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
|
||||
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
std::vector<at::Tensor> result = {out, x.value()};
|
||||
if (has_z) { result.push_back(out_z); }
|
||||
return result;
|
||||
}
|
||||
|
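The heart of the kernel above is the call to BlockScanT::InclusiveScan: each thread_data entry packs the pair (a, b) = (exp(delta*A), delta*B*u), and SSMScanOp (defined in selective_scan.h, which is not shown in this diff) folds those pairs with an associative operator so that the recurrence h_t = a_t*h_{t-1} + b_t can be evaluated as a parallel scan, with prefix_op carrying the running state across chunks. A minimal sketch of that operator, in plain Python and not part of this commit, checked against the sequential recurrence:

import math

def combine(left, right):
    # Composing h -> a1*h + b1 with h -> a2*h + b2 gives
    # h -> (a2*a1)*h + (a2*b1 + b2); this composition is associative.
    a1, b1 = left
    a2, b2 = right
    return a2 * a1, a2 * b1 + b2

a = [math.exp(-0.5 * (t + 1)) for t in range(4)]   # stands in for exp(delta*A)
b = [0.1 * (t + 1) for t in range(4)]              # stands in for delta*B*u

# Sequential recurrence h_t = a_t*h_{t-1} + b_t with h_0 = 0 ...
h, seq = 0.0, []
for at, bt in zip(a, b):
    h = at * h + bt
    seq.append(h)

# ... equals the second component of an inclusive scan with `combine`,
# started from the identity element (1, 0).
acc, scan = (1.0, 0.0), []
for at, bt in zip(a, b):
    acc = combine(acc, (at, bt))
    scan.append(acc[1])

assert all(abs(s - p) < 1e-12 for s, p in zip(seq, scan))

Because the operator is associative, CUB is free to combine partial results per warp and per block in any order, which is what lets an otherwise sequential state-space recurrence run across thousands of threads.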
28
csrc/mamba/mamba_ssm/static_switch.h
Normal file
@ -0,0 +1,28 @@
|
||||
// Inspired by
|
||||
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
||||
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
||||
|
||||
// clang-format off
|
||||
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/static_switch.h
|
||||
#pragma once
|
||||
|
||||
/// @param COND - a boolean expression to switch by
|
||||
/// @param CONST_NAME - a name given for the constexpr bool variable.
|
||||
/// @param ... - code to execute for true and false
|
||||
///
|
||||
/// Usage:
|
||||
/// ```
|
||||
/// BOOL_SWITCH(flag, BoolConst, [&] {
|
||||
/// some_function<BoolConst>(...);
|
||||
/// });
|
||||
/// ```
|
||||
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
constexpr bool CONST_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
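One consequence of this pattern, together with the earlier comments about fixing kIsVariableB, kIsVariableC and kHasZ "in favor of reduced binary size", is that every boolean left as a BOOL_SWITCH doubles the number of compiled kernels. A hedged back-of-the-envelope count for the selective_scan_fwd_launch code earlier in this diff (CUDA branch), in plain Python:

# selective_scan_fwd_launch expands BOOL_SWITCH for kIsEvenLen and kUseIndex,
# so each (kNThreads, kNItems) bucket and each input dtype compiles 2 * 2
# variants of selective_scan_fwd_kernel.
seqlen_buckets = 5      # <=128, <=256, <=512, <=1024, larger
input_dtypes = 3        # float32, float16, bfloat16 instantiations
bool_switches = 2       # kIsEvenLen, kUseIndex
variants = seqlen_buckets * input_dtypes * 2 ** bool_switches
print(variants)         # 60 kernels; switching on kHasZ too would double it again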
22
csrc/ops.h
@ -195,6 +195,28 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
|
||||
std::vector<torch::Tensor> selective_scan_fwd(
|
||||
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
|
||||
const torch::Tensor& B, const torch::Tensor& C,
|
||||
const c10::optional<torch::Tensor>& D_,
|
||||
const c10::optional<torch::Tensor>& z_,
|
||||
const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
|
||||
const c10::optional<torch::Tensor>& index_,
|
||||
const c10::optional<torch::Tensor>& x);
|
||||
|
||||
at::Tensor causal_conv1d_update(const at::Tensor& x,
|
||||
const at::Tensor& conv_state,
|
||||
const at::Tensor& weight,
|
||||
const c10::optional<at::Tensor>& bias_,
|
||||
bool silu_activation);
|
||||
|
||||
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
|
||||
const c10::optional<at::Tensor>& bias_,
|
||||
const c10::optional<at::Tensor>& seq_idx_,
|
||||
const c10::optional<at::Tensor>& initial_states_,
|
||||
const c10::optional<at::Tensor>& final_states_out_,
|
||||
bool silu_activation);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
using fptr_t = int64_t;
|
||||
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
|
||||
|
@ -202,6 +202,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
|
||||
ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
|
||||
&cutlass_scaled_mm_supports_fp8);
|
||||
// Mamba selective scan kernel
|
||||
ops.def(
|
||||
"selective_scan_fwd(Tensor! u, Tensor! delta,"
|
||||
"Tensor! A, Tensor! B, Tensor! C,"
|
||||
"Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
|
||||
"bool delta_softplus,"
|
||||
"Tensor? index_, Tensor? x) -> Tensor[]");
|
||||
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
||||
|
||||
ops.def(
|
||||
"causal_conv1d_update(Tensor! x,"
|
||||
"Tensor! conv_state,"
|
||||
"Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"bool silu_activation) -> Tensor");
|
||||
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
|
||||
|
||||
ops.def(
|
||||
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"Tensor? seq_idx_,"
|
||||
"Tensor? initial_states_,"
|
||||
"Tensor? final_states_out_,"
|
||||
"bool silu_activation) -> Tensor");
|
||||
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
|
||||
#endif
|
||||
|
||||
// Quantized GEMM for GPTQ.
|
||||
|
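In the schema strings above, `Tensor!` marks an argument the kernel may mutate in place (e.g. the conv state and the scan state buffer x), `Tensor?` marks an optional argument, and `-> Tensor[]` a list of tensors; once registered, the ops are reachable as `torch.ops._C.*`, which is exactly what the thin wrappers added to vllm/_custom_ops.py further down in this diff call. A hedged call sketch (argument order follows the registered schema; the tensors themselves are assumed to be prepared by the caller on a CUDA device):

import torch

def call_selective_scan(u, delta, A, B, C, D, z, delta_bias, x):
    # Positional order must match the schema registered above.
    return torch.ops._C.selective_scan_fwd(u, delta, A, B, C,
                                           D, z, delta_bias,
                                           True,   # delta_softplus
                                           None,   # index_
                                           x)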
@ -1,3 +0,0 @@
|
||||
# Mamba dependencies
|
||||
mamba-ssm>=1.2.2
|
||||
causal-conv1d>=1.2.0
|
@ -11,7 +11,7 @@ pytest-shard
|
||||
|
||||
# testing utils
|
||||
awscli
|
||||
einops # required for MPT and qwen-vl
|
||||
einops # required for MPT, qwen-vl and Mamba
|
||||
httpx
|
||||
peft
|
||||
requests
|
||||
|
205
tests/kernels/test_causal_conv1d.py
Normal file
@ -0,0 +1,205 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
|
||||
|
||||
def causal_conv1d_ref(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
initial_states: Optional[torch.Tensor] = None,
|
||||
return_final_states: bool = False,
|
||||
final_states_out: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen)
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
initial_states: (batch, dim, width - 1)
|
||||
final_states_out: (batch, dim, width - 1)
|
||||
|
||||
out: (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
dtype_in = x.dtype
|
||||
x = x.to(weight.dtype)
|
||||
seqlen = x.shape[-1]
|
||||
dim, width = weight.shape
|
||||
if initial_states is None:
|
||||
out = F.conv1d(x,
|
||||
weight.unsqueeze(1),
|
||||
bias,
|
||||
padding=width - 1,
|
||||
groups=dim)
|
||||
else:
|
||||
x = torch.cat([initial_states, x], dim=-1)
|
||||
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
|
||||
out = out[..., :seqlen]
|
||||
if return_final_states:
|
||||
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
|
||||
dtype_in) # (batch, dim, width - 1)
|
||||
if final_states_out is not None:
|
||||
final_states_out.copy_(final_states)
|
||||
else:
|
||||
final_states_out = final_states
|
||||
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
||||
return (out, None) if not return_final_states else (out, final_states_out)
|
||||
|
||||
|
||||
def causal_conv1d_update_ref(x: torch.Tensor,
|
||||
conv_state: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = None):
|
||||
"""
|
||||
x: (batch, dim)
|
||||
conv_state: (batch, dim, width)
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
|
||||
out: (batch, dim)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
dtype_in = x.dtype
|
||||
batch, dim = x.shape
|
||||
width = weight.shape[1]
|
||||
assert conv_state.shape == (batch, dim, width)
|
||||
assert weight.shape == (dim, width)
|
||||
conv_state.copy_(torch.roll(conv_state, shifts=-1,
|
||||
dims=-1)) # Update state (B D W)
|
||||
conv_state[:, :, -1] = x
|
||||
out = torch.sum(conv_state * weight, dim=-1) # (B D)
|
||||
if bias is not None:
|
||||
out += bias
|
||||
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("return_final_states", [False, True])
|
||||
@pytest.mark.parametrize("has_initial_states", [False, True])
|
||||
@pytest.mark.parametrize("channel_last", [False, True])
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||
@pytest.mark.parametrize("has_bias", [False, True])
|
||||
@pytest.mark.parametrize("width", [4])
|
||||
@pytest.mark.parametrize("seqlen", [128, 512, 4096])
|
||||
@pytest.mark.parametrize('dim', [64, 4096 + 32])
|
||||
@pytest.mark.parametrize('batch', [1, 2])
|
||||
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
|
||||
itype, channel_last, has_initial_states,
|
||||
return_final_states):
|
||||
if not channel_last and (has_initial_states or return_final_states):
|
||||
pytest.skip(
|
||||
"Only channel_last support initial_states or return_final_states")
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
if not channel_last:
|
||||
x = torch.randn(batch,
|
||||
4096 + dim + 64,
|
||||
seqlen,
|
||||
device=device,
|
||||
dtype=itype)[:, 4096:4096 + dim, :]
|
||||
else:
|
||||
x = rearrange(
|
||||
torch.randn(batch,
|
||||
seqlen,
|
||||
4096 + dim + 64,
|
||||
device=device,
|
||||
dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s")
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
if has_initial_states:
|
||||
initial_states = torch.randn(batch,
|
||||
width - 1,
|
||||
dim,
|
||||
device=device,
|
||||
dtype=itype).transpose(1, 2)
|
||||
else:
|
||||
initial_states = None
|
||||
x_ref = x.detach().clone()
|
||||
weight_ref = weight.detach().clone()
|
||||
bias_ref = bias.detach().clone() if bias is not None else None
|
||||
initial_states_ref = initial_states.detach().clone(
|
||||
) if initial_states is not None else None
|
||||
activation = None if not silu_activation else "silu"
|
||||
out, final_states = causal_conv1d_fn(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
initial_states=initial_states,
|
||||
return_final_states=return_final_states,
|
||||
activation=activation)
|
||||
out_ref, final_states_ref = causal_conv1d_ref(
|
||||
x_ref,
|
||||
weight_ref,
|
||||
bias_ref,
|
||||
initial_states=initial_states_ref,
|
||||
return_final_states=return_final_states,
|
||||
activation=activation)
|
||||
if return_final_states:
|
||||
assert final_states is not None and final_states_ref is not None
|
||||
assert torch.allclose(final_states,
|
||||
final_states_ref,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
if return_final_states:
|
||||
out += F.sigmoid(final_states).sum(dim=-1, keepdim=True)
|
||||
out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||
@pytest.mark.parametrize("has_bias", [False, True])
|
||||
@pytest.mark.parametrize("width", [2, 3, 4])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
@pytest.mark.parametrize("batch", [1, 2])
|
||||
def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
|
||||
itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch = 2
|
||||
x = torch.randn(batch, dim, device=device, dtype=itype)
|
||||
conv_state = torch.randn(batch, dim, width, device=device, dtype=itype)
|
||||
weight = torch.randn(dim,
|
||||
width,
|
||||
device=device,
|
||||
dtype=itype,
|
||||
requires_grad=True)
|
||||
if has_bias:
|
||||
bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
|
||||
else:
|
||||
bias = None
|
||||
conv_state_ref = conv_state.detach().clone()
|
||||
activation = None if not silu_activation else "silu"
|
||||
out = causal_conv1d_update(x,
|
||||
conv_state,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation)
|
||||
out_ref = causal_conv1d_update_ref(x,
|
||||
conv_state_ref,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation)
|
||||
|
||||
assert torch.equal(conv_state, conv_state_ref)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
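The channel_last branches above depend on a layout detail that is easy to miss: the kernels want x.stride(1) == 1 whenever initial or final states are involved, and the tests obtain that by allocating (batch, seqlen, dim) storage and exposing it as a (batch, dim, seqlen) view. A minimal sketch of that trick (not from this commit; any CUDA tensor behaves the same way):

import torch

batch, dim, seqlen = 2, 64, 128
# Channel-last storage with the (batch, dim, seqlen) view the kernels expect;
# transpose only swaps strides, no data is copied.
x = torch.randn(batch, seqlen, dim, device="cuda").transpose(1, 2)
assert x.shape == (batch, dim, seqlen)
assert x.stride(1) == 1   # channels are contiguous in memory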
324
tests/kernels/test_mamba_ssm.py
Normal file
@ -0,0 +1,324 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
selective_scan_fn, selective_state_update)
|
||||
|
||||
|
||||
def selective_state_update_ref(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
dt_softplus=False):
|
||||
"""
|
||||
Argument:
|
||||
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
x: (batch, dim) or (batch, nheads, dim)
|
||||
dt: (batch, dim) or (batch, nheads, dim)
|
||||
A: (dim, dstate) or (nheads, dim, dstate)
|
||||
B: (batch, dstate) or (batch, ngroups, dstate)
|
||||
C: (batch, dstate) or (batch, ngroups, dstate)
|
||||
D: (dim,) or (nheads, dim)
|
||||
z: (batch, dim) or (batch, nheads, dim)
|
||||
dt_bias: (dim,) or (nheads, dim)
|
||||
Return:
|
||||
out: (batch, dim) or (batch, nheads, dim)
|
||||
"""
|
||||
has_heads = state.dim() > 3
|
||||
if state.dim() == 3:
|
||||
state = state.unsqueeze(1)
|
||||
if x.dim() == 2:
|
||||
x = x.unsqueeze(1)
|
||||
if dt.dim() == 2:
|
||||
dt = dt.unsqueeze(1)
|
||||
if A.dim() == 2:
|
||||
A = A.unsqueeze(0)
|
||||
if B.dim() == 2:
|
||||
B = B.unsqueeze(1)
|
||||
if C.dim() == 2:
|
||||
C = C.unsqueeze(1)
|
||||
if D is not None and D.dim() == 1:
|
||||
D = D.unsqueeze(0)
|
||||
if z is not None and z.dim() == 2:
|
||||
z = z.unsqueeze(1)
|
||||
if dt_bias is not None and dt_bias.dim() == 1:
|
||||
dt_bias = dt_bias.unsqueeze(0)
|
||||
batch, nheads, dim, dstate = state.shape
|
||||
assert x.shape == (batch, nheads, dim)
|
||||
assert dt.shape == x.shape
|
||||
assert A.shape == (nheads, dim, dstate)
|
||||
ngroups = B.shape[1]
|
||||
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
|
||||
assert B.shape == (batch, ngroups, dstate)
|
||||
assert C.shape == B.shape
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, dim)
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads, dim)
|
||||
dt = dt + dt_bias
|
||||
dt = F.softplus(dt) if dt_softplus else dt
|
||||
dA = torch.exp(rearrange(dt, "b h d -> b h d 1") *
|
||||
A) # (batch, nheads, dim, dstate)
|
||||
B = repeat(B, "b g n -> b (g h) n",
|
||||
h=nheads // ngroups) # (batch, nheads, dstate)
|
||||
C = repeat(C, "b g n -> b (g h) n",
|
||||
h=nheads // ngroups) # (batch, nheads, dstate)
|
||||
dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
|
||||
B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate)
|
||||
state.copy_(state * dA +
|
||||
dB * rearrange(x, "b h d -> b h d 1"))  # (batch, nheads, dim, dstate)
|
||||
out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
|
||||
if D is not None:
|
||||
out += (x * D).to(out.dtype)
|
||||
out = (out if z is None else out * F.silu(z)).to(x.dtype)
|
||||
if not has_heads:
|
||||
out = out.squeeze(1)
|
||||
return out
|
||||
|
||||
|
||||
def selective_scan_ref(u,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
delta_bias=None,
|
||||
delta_softplus=False,
|
||||
return_last_state=False,
|
||||
position_indices=None,
|
||||
prev_state=None):
|
||||
"""
|
||||
u: r(B D L)
|
||||
delta: r(B D L)
|
||||
A: c(D N) or r(D N)
|
||||
B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
||||
C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
||||
D: r(D)
|
||||
z: r(B D L)
|
||||
delta_bias: r(D), fp32
|
||||
prev_state: r(B D N), fp32
|
||||
|
||||
out: r(B D L)
|
||||
last_state (optional): r(B D dstate) or c(B D dstate)
|
||||
"""
|
||||
dtype_in = u.dtype
|
||||
u = u.float()
|
||||
delta = delta.float()
|
||||
if delta_bias is not None:
|
||||
delta = delta + delta_bias[..., None].float()
|
||||
if delta_softplus:
|
||||
delta = F.softplus(delta)
|
||||
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
|
||||
is_variable_B = B.dim() >= 3
|
||||
is_variable_C = C.dim() >= 3
|
||||
B = B.float()
|
||||
C = C.float()
|
||||
x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state
|
||||
ys = []
|
||||
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
|
||||
if not is_variable_B:
|
||||
deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
|
||||
else:
|
||||
if B.dim() == 3:
|
||||
deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
|
||||
else:
|
||||
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
|
||||
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
|
||||
if is_variable_C and C.dim() == 4:
|
||||
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
|
||||
last_state = None
|
||||
for i in range(u.shape[2]):
|
||||
if position_indices is not None and position_indices[0, i] == 0:
|
||||
x = deltaB_u[:, :, i]
|
||||
else:
|
||||
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
|
||||
if not is_variable_C:
|
||||
y = torch.einsum('bdn,dn->bd', x, C)
|
||||
else:
|
||||
if C.dim() == 3:
|
||||
y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
|
||||
else:
|
||||
y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
|
||||
if i == u.shape[2] - 1:
|
||||
last_state = x
|
||||
ys.append(y)
|
||||
y = torch.stack(ys, dim=2) # (batch dim L)
|
||||
out = y if D is None else y + u * rearrange(D, "d -> d 1")
|
||||
if z is not None:
|
||||
out = out * F.silu(z)
|
||||
out = out.to(dtype=dtype_in)
|
||||
return out if not return_last_state else (out, last_state)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('wtype', [torch.float32])
|
||||
@pytest.mark.parametrize('itype', [torch.float32])
|
||||
@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
|
||||
@pytest.mark.parametrize("return_last_state", [True])
|
||||
@pytest.mark.parametrize('has_delta_bias', [True])
|
||||
@pytest.mark.parametrize('delta_softplus', [True])
|
||||
@pytest.mark.parametrize('has_z', [True])
|
||||
@pytest.mark.parametrize('has_D', [True])
|
||||
@pytest.mark.parametrize("varBC_groups", [1, 2])
|
||||
@pytest.mark.parametrize("is_variable_C", [True])
|
||||
@pytest.mark.parametrize("is_variable_B", [True])
|
||||
@pytest.mark.parametrize("scan_chunks", [1, 2, 3])
|
||||
def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
|
||||
has_z, has_delta_bias, delta_softplus,
|
||||
return_last_state, seqlen, itype, wtype, scan_chunks):
|
||||
if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
|
||||
pytest.skip() # This config is not applicable
|
||||
device = 'cuda'
|
||||
rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 3e-2, 5e-2
|
||||
rtolw, atolw = (1e-3, 1e-3)
|
||||
if has_z: # If we have z, the errors on the weights seem higher
|
||||
rtolw = max(rtolw, rtol)
|
||||
atolw = max(atolw, atol)
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 2
|
||||
dim = 4
|
||||
dstate = 8
|
||||
A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype))
|
||||
if not is_variable_B:
|
||||
B_shape = [dim, dstate]
|
||||
elif varBC_groups == 1:
|
||||
B_shape = [batch_size, dstate, seqlen]
|
||||
else:
|
||||
B_shape = [batch_size, varBC_groups, dstate, seqlen]
|
||||
B = torch.randn(B_shape,
|
||||
device=device,
|
||||
dtype=wtype if not is_variable_B else itype)
|
||||
if not is_variable_C:
|
||||
C_shape = [dim, dstate]
|
||||
elif varBC_groups == 1:
|
||||
C_shape = [batch_size, dstate, seqlen]
|
||||
else:
|
||||
C_shape = [batch_size, varBC_groups, dstate, seqlen]
|
||||
C = torch.randn(C_shape,
|
||||
device=device,
|
||||
dtype=wtype if not is_variable_C else itype)
|
||||
D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
|
||||
z = torch.randn(batch_size, dim, seqlen, device=device,
|
||||
dtype=itype) if has_z else None
|
||||
delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)
|
||||
) if has_delta_bias else None
|
||||
u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
|
||||
delta = (0.5 *
|
||||
torch.rand(batch_size, dim, seqlen, device=device, dtype=itype))
|
||||
state = None
|
||||
state_ref = None
|
||||
out = None
|
||||
out_ref = None
|
||||
outs = []
|
||||
for c in range(scan_chunks):
|
||||
chunked_prompt_len = seqlen // scan_chunks
|
||||
chunk_start = chunked_prompt_len * c
|
||||
chunk_end = chunked_prompt_len * (c + 1)
|
||||
if c == scan_chunks - 1:
|
||||
chunk_end = seqlen
|
||||
_B = B
|
||||
if is_variable_B:
|
||||
_B = B[..., chunk_start:chunk_end]
|
||||
_C = C
|
||||
if is_variable_C:
|
||||
_C = C[..., chunk_start:chunk_end]
|
||||
_z = z
|
||||
if has_z:
|
||||
assert z is not None
|
||||
_z = z[..., chunk_start:chunk_end]
|
||||
out, *rest = selective_scan_fn(u[..., chunk_start:chunk_end],
|
||||
delta[..., chunk_start:chunk_end],
|
||||
A,
|
||||
_B,
|
||||
_C,
|
||||
D,
|
||||
z=_z,
|
||||
delta_bias=delta_bias,
|
||||
delta_softplus=delta_softplus,
|
||||
return_last_state=return_last_state,
|
||||
prev_state=state if c > 0 else None)
|
||||
outs.append(out)
|
||||
if return_last_state:
|
||||
state = rest[0]
|
||||
if len(outs) > 1:
|
||||
out = torch.cat(outs, dim=-1)
|
||||
out_ref, *rest = selective_scan_ref(u,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
z=z,
|
||||
delta_bias=delta_bias,
|
||||
delta_softplus=delta_softplus,
|
||||
return_last_state=return_last_state)
|
||||
if return_last_state:
|
||||
state_ref = rest[0]
|
||||
|
||||
assert out is not None and out_ref is not None
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
if return_last_state:
|
||||
assert state is not None and state_ref is not None
|
||||
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("has_z", [False, True])
|
||||
@pytest.mark.parametrize("dstate", [16, 32, 64])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
def test_selective_state_update(dim, dstate, has_z, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
if torch.version.hip:
|
||||
atol *= 2
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 1
|
||||
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
|
||||
x = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||
dt_bias = torch.rand(dim, device=device) - 4.0
|
||||
A = -torch.rand(dim, dstate, device=device) - 1.0
|
||||
B = torch.randn(batch_size, dstate, device=device)
|
||||
C = torch.randn(batch_size, dstate, device=device)
|
||||
D = torch.randn(dim, device=device)
|
||||
z = torch.randn_like(x) if has_z else None
|
||||
state_ref = state.detach().clone()
|
||||
out = selective_state_update(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True)
|
||||
out_ref = selective_state_update_ref(state_ref,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True)
|
||||
|
||||
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
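Taken together, the two tests above mirror how these kernels are meant to be driven in a Mamba forward pass: selective_scan_fn consumes the prompt (optionally in chunks, threading prev_state through) and returns the final SSM state, which selective_state_update then advances one token at a time during decode. A hedged usage sketch under those assumptions -- shapes follow the tests, the single-token names (x_t, dt_t, B_t, C_t, z_t) are illustrative, and a CUDA build with these kernels is assumed; this is not code from the commit:

import torch
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
    selective_scan_fn, selective_state_update)

device, batch, dim, dstate, seqlen = "cuda", 2, 4, 8, 32
A = -torch.rand(dim, dstate, device=device)            # (dim, dstate), float32
D = torch.randn(dim, device=device)
delta_bias = torch.rand(dim, device=device)

u = torch.randn(batch, dim, seqlen, device=device)
delta = torch.rand(batch, dim, seqlen, device=device)
B = torch.randn(batch, dstate, seqlen, device=device)  # B/C vary over seqlen
C = torch.randn(batch, dstate, seqlen, device=device)
z = torch.randn(batch, dim, seqlen, device=device)

# Prefill: scan the whole prompt and keep the last SSM state.
out, *rest = selective_scan_fn(u, delta, A, B, C, D, z=z,
                               delta_bias=delta_bias,
                               delta_softplus=True,
                               return_last_state=True)
ssm_state = rest[0]                                    # (batch, dim, dstate)

# Decode: advance the cached state with one new token.
x_t = torch.randn(batch, dim, device=device)
dt_t = torch.rand(batch, dim, device=device)
B_t = torch.randn(batch, dstate, device=device)
C_t = torch.randn(batch, dstate, device=device)
z_t = torch.randn(batch, dim, device=device)
y_t = selective_state_update(ssm_state, x_t, dt_t, A, B_t, C_t,
                             D=D, z=z_t, dt_bias=delta_bias,
                             dt_softplus=True)         # (batch, dim)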
@ -500,6 +500,36 @@ def ggml_mul_mat_a8(
|
||||
return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)
|
||||
|
||||
|
||||
# mamba
|
||||
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
|
||||
bias_: Optional[torch.Tensor],
|
||||
seq_idx_: Optional[torch.Tensor],
|
||||
initial_states_: Optional[torch.Tensor],
|
||||
final_states_out_: Optional[torch.Tensor],
|
||||
silu_activation: bool) -> torch.Tensor:
|
||||
return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_,
|
||||
initial_states_, final_states_out_,
|
||||
silu_activation)
|
||||
|
||||
|
||||
def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
|
||||
weight: torch.Tensor, bias_: Optional[torch.Tensor],
|
||||
silu_activation: bool) -> torch.Tensor:
|
||||
return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
|
||||
silu_activation)
|
||||
|
||||
|
||||
def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
|
||||
B: torch.Tensor, C: torch.Tensor,
|
||||
D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
|
||||
delta_bias_: Optional[torch.Tensor],
|
||||
delta_softplus: bool, index_: Optional[torch.Tensor],
|
||||
x: Optional[torch.Tensor]) -> List[torch.Tensor]:
|
||||
return torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_,
|
||||
delta_bias_, delta_softplus, index_,
|
||||
x)
|
||||
|
||||
|
||||
# moe
|
||||
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
|
||||
block_size: int, sorted_token_ids: torch.Tensor,
|
||||
|
0
vllm/model_executor/layers/mamba/__init__.py
Normal file
0
vllm/model_executor/layers/mamba/ops/__init__.py
Normal file
86
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
Normal file
@ -0,0 +1,86 @@
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
|
||||
def causal_conv1d_fn(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
seq_idx: Optional[torch.Tensor] = None,
|
||||
initial_states: Optional[torch.Tensor] = None,
|
||||
return_final_states: bool = False,
|
||||
final_states_out=None,
|
||||
activation: str = "silu",
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen)
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
seq_idx: (batch, seqlen)
|
||||
initial_states: (batch, dim, width - 1)
|
||||
final_states_out: (batch, dim, width - 1), to be written to
|
||||
activation: either None or "silu" or "swish"
|
||||
|
||||
out: (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
if x.stride(2) != 1 and x.stride(1) != 1:
|
||||
x = x.contiguous()
|
||||
bias = bias.contiguous() if bias is not None else None
|
||||
if seq_idx is not None:
|
||||
assert (initial_states is
|
||||
None), "initial_states must be None if seq_idx is not None"
|
||||
assert (not return_final_states
|
||||
), "If seq_idx is not None, we don't return final_states_out"
|
||||
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
|
||||
if initial_states is not None and (initial_states.stride(2) != 1
|
||||
and initial_states.stride(1) != 1):
|
||||
initial_states = initial_states.contiguous()
|
||||
if return_final_states:
|
||||
assert (
|
||||
x.stride(1) == 1
|
||||
), "Only channel-last layout support returning final_states_out"
|
||||
if final_states_out is not None:
|
||||
assert (final_states_out.stride(2) == 1
|
||||
or final_states_out.stride(1) == 1)
|
||||
else:
|
||||
batch, dim, seqlen = x.shape
|
||||
width = weight.shape[1]
|
||||
final_states_out = torch.empty(batch,
|
||||
width - 1,
|
||||
dim,
|
||||
device=x.device,
|
||||
dtype=x.dtype).transpose(1, 2)
|
||||
else:
|
||||
final_states_out = None
|
||||
|
||||
out = ops.causal_conv1d_fwd(x, weight, bias, seq_idx, initial_states,
|
||||
final_states_out, activation
|
||||
in ["silu", "swish"])
|
||||
return (out, None) if not return_final_states else (out, final_states_out)
|
||||
|
||||
|
||||
def causal_conv1d_update(x: torch.Tensor,
|
||||
conv_state: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = None):
|
||||
"""
|
||||
x: (batch, dim)
|
||||
conv_state: (batch, dim, width)
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
|
||||
out: (batch, dim)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
activation_bool = activation in ["silu", "swish"]
|
||||
return ops.causal_conv1d_update(x, conv_state, weight, bias,
|
||||
activation_bool)
|
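As with the SSM scan, these two wrappers split prefill and decode: causal_conv1d_fn convolves a whole prompt in one fused pass, while causal_conv1d_update rolls a (batch, dim, width) window forward one token at a time, mutating conv_state in place. A hedged sketch, not from this commit and assuming a CUDA build with these kernels, showing that stepping the update from a zero state reproduces the full causal convolution (tolerances are illustrative):

import torch
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)

batch, dim, width, seqlen = 2, 64, 4, 16
x = torch.randn(batch, dim, seqlen, device="cuda")
weight = torch.randn(dim, width, device="cuda")

# Prefill path: one pass over the whole sequence.
full, _ = causal_conv1d_fn(x, weight, activation="silu")

# Decode path: the state holds the last `width` inputs per channel and is
# updated in place at every step.
conv_state = torch.zeros(batch, dim, width, device="cuda")
stepped = torch.stack(
    [causal_conv1d_update(x[..., t].contiguous(), conv_state, weight,
                          activation="silu")
     for t in range(seqlen)],
    dim=-1)

assert torch.allclose(stepped, full, rtol=1e-3, atol=1e-3)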
346
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
Normal file
@ -0,0 +1,346 @@
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from packaging import version
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
|
||||
|
||||
if TRITON3:
|
||||
|
||||
@triton.jit
|
||||
def softplus(dt):
|
||||
dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
|
||||
return dt
|
||||
else:
|
||||
|
||||
@triton.jit
|
||||
def softplus(dt):
|
||||
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
|
||||
return dt
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
|
||||
@triton.heuristics(
|
||||
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
|
||||
@triton.jit
|
||||
def _selective_scan_update_kernel(
|
||||
# Pointers to matrices
|
||||
state_ptr,
|
||||
x_ptr,
|
||||
dt_ptr,
|
||||
dt_bias_ptr,
|
||||
A_ptr,
|
||||
B_ptr,
|
||||
C_ptr,
|
||||
D_ptr,
|
||||
z_ptr,
|
||||
out_ptr,
|
||||
# Matrix dimensions
|
||||
batch,
|
||||
nheads,
|
||||
dim,
|
||||
dstate,
|
||||
nheads_ngroups_ratio,
|
||||
# Strides
|
||||
stride_state_batch,
|
||||
stride_state_head,
|
||||
stride_state_dim,
|
||||
stride_state_dstate,
|
||||
stride_x_batch,
|
||||
stride_x_head,
|
||||
stride_x_dim,
|
||||
stride_dt_batch,
|
||||
stride_dt_head,
|
||||
stride_dt_dim,
|
||||
stride_dt_bias_head,
|
||||
stride_dt_bias_dim,
|
||||
stride_A_head,
|
||||
stride_A_dim,
|
||||
stride_A_dstate,
|
||||
stride_B_batch,
|
||||
stride_B_group,
|
||||
stride_B_dstate,
|
||||
stride_C_batch,
|
||||
stride_C_group,
|
||||
stride_C_dstate,
|
||||
stride_D_head,
|
||||
stride_D_dim,
|
||||
stride_z_batch,
|
||||
stride_z_head,
|
||||
stride_z_dim,
|
||||
stride_out_batch,
|
||||
stride_out_head,
|
||||
stride_out_dim,
|
||||
# Meta-parameters
|
||||
DT_SOFTPLUS: tl.constexpr,
|
||||
TIE_HDIM: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
HAS_DT_BIAS: tl.constexpr,
|
||||
HAS_D: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||
):
|
||||
pid_m = tl.program_id(axis=0)
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
|
||||
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
|
||||
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias_ptr += pid_h * stride_dt_bias_head
|
||||
A_ptr += pid_h * stride_A_head
|
||||
B_ptr += pid_b * stride_B_batch + (pid_h //
|
||||
nheads_ngroups_ratio) * stride_B_group
|
||||
C_ptr += pid_b * stride_C_batch + (pid_h //
|
||||
nheads_ngroups_ratio) * stride_C_group
|
||||
if HAS_Z:
|
||||
z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
|
||||
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
|
||||
state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim +
|
||||
offs_n[None, :] * stride_state_dstate)
|
||||
x_ptrs = x_ptr + offs_m * stride_x_dim
|
||||
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
|
||||
if HAS_D:
|
||||
D_ptr += pid_h * stride_D_head
|
||||
A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim +
|
||||
offs_n[None, :] * stride_A_dstate)
|
||||
B_ptrs = B_ptr + offs_n * stride_B_dstate
|
||||
C_ptrs = C_ptr + offs_n * stride_C_dstate
|
||||
if HAS_D:
|
||||
D_ptrs = D_ptr + offs_m * stride_D_dim
|
||||
if HAS_Z:
|
||||
z_ptrs = z_ptr + offs_m * stride_z_dim
|
||||
out_ptrs = out_ptr + offs_m * stride_out_dim
|
||||
|
||||
state = tl.load(state_ptrs,
|
||||
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
|
||||
other=0.0)
|
||||
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if not TIE_HDIM:
|
||||
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim,
|
||||
other=0.0).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
A = tl.load(A_ptrs,
|
||||
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
|
||||
other=0.0).to(tl.float32)
|
||||
dA = tl.exp(A * dt[:, None])
|
||||
else:
|
||||
dt = tl.load(dt_ptr).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptr).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
A = tl.load(A_ptr).to(tl.float32)
|
||||
dA = tl.exp(A * dt) # scalar, not a matrix
|
||||
|
||||
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
if HAS_D:
|
||||
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_Z:
|
||||
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
|
||||
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
|
||||
state = state * dA + dB * x[:, None]
|
||||
tl.store(state_ptrs,
|
||||
state,
|
||||
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
|
||||
out = tl.sum(state * C[None, :], axis=1)
|
||||
if HAS_D:
|
||||
out += x * D
|
||||
if HAS_Z:
|
||||
out *= z * tl.sigmoid(z)
|
||||
tl.store(out_ptrs, out, mask=offs_m < dim)
|
||||
|
||||
|
||||
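# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original commit): a rough pure-PyTorch
# reference of the single-step recurrence that _selective_scan_update_kernel
# computes, included only to make the Triton code easier to follow. The helper
# name _selective_state_update_ref and the use of repeat_interleave for
# group-to-head broadcasting are illustrative assumptions; shapes follow the
# selective_state_update docstring below.
# ---------------------------------------------------------------------------
def _selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None,
                                dt_bias=None, dt_softplus=False):
    # state: (batch, nheads, dim, dstate); x/dt/z: (batch, nheads, dim)
    # A: (nheads, dim, dstate); B/C: (batch, ngroups, dstate); D/dt_bias: (nheads, dim)
    batch, nheads, dim, dstate = state.shape
    ngroups = B.shape[1]
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = torch.nn.functional.softplus(dt)
    # One B/C group serves nheads // ngroups consecutive heads, mirroring
    # pid_h // nheads_ngroups_ratio in the kernel above.
    B = B.repeat_interleave(nheads // ngroups, dim=1)  # (batch, nheads, dstate)
    C = C.repeat_interleave(nheads // ngroups, dim=1)
    dA = torch.exp(A * dt.unsqueeze(-1))            # (batch, nheads, dim, dstate)
    dB = B.unsqueeze(2) * dt.unsqueeze(-1)          # (batch, nheads, dim, dstate)
    state.copy_(state * dA + dB * x.unsqueeze(-1))  # state is updated in place
    out = (state * C.unsqueeze(2)).sum(-1)          # (batch, nheads, dim)
    if D is not None:
        out = out + x * D
    if z is not None:
        out = out * (z * torch.sigmoid(z))
    return out
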
def selective_state_update(state,
                           x,
                           dt,
                           A,
                           B,
                           C,
                           D=None,
                           z=None,
                           dt_bias=None,
                           dt_softplus=False):
    """
    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.dim() > 3
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)
    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
    z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
                 (0, 0, 0))
    # We don't want autotune since it will overwrite the state
    # We instead tune by hand.
    BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else
                               ((16, 4) if dstate <= 32 else
                                ((8, 4) if dstate <= 64 else
                                 ((4, 4) if dstate <= 128 else ((4, 8))))))
    tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(
        -1) == 0 and dt_bias.stride(-1) == 0
    with torch.cuda.device(x.device.index):
        _selective_scan_update_kernel[grid](
            state,
            x,
            dt,
            dt_bias,
            A,
            B,
            C,
            D,
            z,
            out,
            batch,
            nheads,
            dim,
            dstate,
            nheads // ngroups,
            state.stride(0),
            state.stride(1),
            state.stride(2),
            state.stride(3),
            x.stride(0),
            x.stride(1),
            x.stride(2),
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            *(dt_bias.stride(0),
              dt_bias.stride(1)) if dt_bias is not None else 0,
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            C.stride(0),
            C.stride(1),
            C.stride(2),
            *(D.stride(0), D.stride(1)) if D is not None else 0,
            z_strides[0],
            z_strides[1],
            z_strides[2],
            out.stride(0),
            out.stride(1),
            out.stride(2),
            dt_softplus,
            tie_hdim,
            BLOCK_SIZE_M,
            num_warps=num_warps,
        )
    if not has_heads:
        out = out.squeeze(1)
    return out

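# Editorial sketch (not part of the original commit): a hypothetical single-step
# decode call to selective_state_update. Every shape below is made up for
# illustration, and the helper name _selective_state_update_example is not part
# of this module; a CUDA device is required since the Triton kernel is launched.
def _selective_state_update_example():
    batch, nheads, dim, dstate, ngroups = 2, 8, 64, 16, 1
    device = "cuda"
    state = torch.zeros(batch, nheads, dim, dstate, device=device)
    x = torch.randn(batch, nheads, dim, device=device)
    dt = torch.rand(batch, nheads, dim, device=device)
    A = -torch.rand(nheads, dim, dstate, device=device)
    B = torch.randn(batch, ngroups, dstate, device=device)
    C = torch.randn(batch, ngroups, dstate, device=device)
    D = torch.ones(nheads, dim, device=device)
    dt_bias = torch.zeros(nheads, dim, device=device)
    out = selective_state_update(state, x, dt, A, B, C, D=D, dt_bias=dt_bias,
                                 dt_softplus=True)
    # out: (batch, nheads, dim); `state` now holds the advanced SSM state.
    return out
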
def selective_scan_fn(u,
                      delta,
                      A,
                      B,
                      C,
                      D=None,
                      z=None,
                      delta_bias=None,
                      delta_softplus=False,
                      return_last_state=False,
                      position_indices=None,
                      prev_state=None):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate).
    """
    if u.stride(-1) != 1:
        u = u.contiguous()
    if delta.stride(-1) != 1:
        delta = delta.contiguous()
    if D is not None:
        D = D.contiguous()
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if z is not None and z.stride(-1) != 1:
        z = z.contiguous()
    if B.dim() == 3:
        B = B.unsqueeze(1)
    if C.dim() == 3:
        C = C.unsqueeze(1)
    n_chunks = int((u.shape[-1] + 2048 - 1) / 2048)
    x = torch.zeros((
        u.shape[0],
        u.shape[1],
        n_chunks,
        int(A.shape[1] * 2),
    ),
                    device=u.device,
                    dtype=torch.float32,
                    requires_grad=False)
    x[:, :, 0, 0::2] = 1
    if prev_state is not None:
        x[:, :, 0, 1::2].copy_(prev_state)
    out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
                                           delta_softplus, position_indices, x)
    last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
    if z is None:
        return out if not return_last_state else (out, last_state)
    else:
        out_z = rest[0]
        return out_z if not return_last_state else (out_z, last_state)
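# Editorial sketch (not part of the original commit): a hypothetical prefill-time
# call to selective_scan_fn with return_last_state=True. Shapes follow the usual
# (batch, dim, seqlen) selective-scan layout and are illustrative only; the
# helper name _selective_scan_example is not part of this module, and float32
# inputs are an assumption about what the underlying CUDA op accepts.
def _selective_scan_example():
    batch, dim, dstate, seqlen = 2, 64, 16, 128
    device = "cuda"
    u = torch.randn(batch, dim, seqlen, device=device)
    delta = torch.rand(batch, dim, seqlen, device=device)
    A = -torch.rand(dim, dstate, device=device)
    B = torch.randn(batch, dstate, seqlen, device=device)
    C = torch.randn(batch, dstate, seqlen, device=device)
    D = torch.ones(dim, device=device)
    z = torch.randn(batch, dim, seqlen, device=device)
    delta_bias = torch.zeros(dim, device=device)
    out, last_state = selective_scan_fn(u, delta, A, B, C, D=D, z=z,
                                        delta_bias=delta_bias,
                                        delta_softplus=True,
                                        return_last_state=True)
    # out: (batch, dim, seqlen); last_state: (batch, dim, dstate), which could
    # presumably be fed back via prev_state for a subsequent chunk.
    return out, last_state
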
@ -4,9 +4,6 @@ from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple

import torch
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from torch import nn
from torch.nn.parameter import Parameter
from transformers import JambaConfig
@ -24,6 +21,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
    selective_scan_fn, selective_state_update)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
@ -161,7 +162,7 @@ class JambaMambaMixer(nn.Module):
                (self.conv_kernel_size - hidden_states.shape[-1], 0))
            cache_params.conv_state.copy_(conv_states)

            hidden_states = causal_conv1d_fn(
            hidden_states, _ = causal_conv1d_fn(
                hidden_states,
                conv_weights,
                self.conv1d.bias,