[MISC] Replace c10::optional with std::optional (#11730)
Signed-off-by: Lu Fang <lufang@fb.com>
Parent: 47831430cc
Commit: 4068f4b5b5
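The change is mechanical: every c10::optional<T> in the hunks below becomes std::optional<T>. In current PyTorch releases, c10::optional is an alias for std::optional (PyTorch builds as C++17), so the rename changes spelling only, not behavior. A minimal sketch of the before/after pattern, using a hypothetical function rather than anything from this diff:

    // Hypothetical example, not part of this commit.
    // Before: const c10::optional<torch::Tensor>& bias
    // After:  const std::optional<torch::Tensor>& bias
    #include <optional>
    #include <torch/torch.h>

    torch::Tensor linear_with_optional_bias(
        const torch::Tensor& x, const torch::Tensor& w,
        const std::optional<torch::Tensor>& bias) {
      auto y = torch::matmul(x, w);
      if (bias.has_value()) {  // an empty optional means "no bias"
        y += *bias;
      }
      return y;
    }

Call sites are unaffected: passing a tensor or std::nullopt works identically before and after the rename.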
@@ -53,7 +53,7 @@ void paged_attention_v1_launcher(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
+    const std::optional<torch::Tensor>& alibi_slopes, float k_scale,
     float v_scale, const int tp_rank, const int blocksparse_local_blocks,
     const int blocksparse_vert_stride, const int blocksparse_block_size,
     const int blocksparse_head_sliding_step) {

@@ -176,7 +176,7 @@ void paged_attention_v1(
     torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
     torch::Tensor& seq_lens,      // [num_seqs]
     int64_t block_size, int64_t max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, double k_scale, double v_scale,
     const int64_t tp_rank, const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,

@@ -54,7 +54,7 @@ void paged_attention_v2_launcher(
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
+    const std::optional<torch::Tensor>& alibi_slopes, float k_scale,
     float v_scale, const int tp_rank, const int blocksparse_local_blocks,
     const int blocksparse_vert_stride, const int blocksparse_block_size,
     const int blocksparse_head_sliding_step) {

@@ -187,7 +187,7 @@ void paged_attention_v2(
     torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
     torch::Tensor& seq_lens,      // [num_seqs]
     int64_t block_size, int64_t max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, double k_scale, double v_scale,
     const int64_t tp_rank, const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,

@@ -386,7 +386,7 @@ void paged_attention_v1_impl_launcher(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes) {
+    const std::optional<torch::Tensor>& alibi_slopes) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);

@@ -459,7 +459,7 @@ void paged_attention_v1(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
-    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, double k_scale, double v_scale,
     const int64_t tp_rank, const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,

@@ -702,7 +702,7 @@ void paged_attention_v2_impl_launcher(
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
-    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes) {
+    int max_seq_len, const std::optional<torch::Tensor>& alibi_slopes) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);

@@ -781,7 +781,7 @@ void paged_attention_v2(
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
-    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, double k_scale, double v_scale,
     const int64_t tp_rank, const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,

@@ -359,7 +359,7 @@ void int8_scaled_mm(torch::Tensor& c,  // [M, OC], row-major
     const torch::Tensor& b,         // [IC, OC], column-major
     const torch::Tensor& a_scales,  // [1] or [M]
     const torch::Tensor& b_scales,  // [1] or [OC]
-    const c10::optional<torch::Tensor>& bias  // [OC]
+    const std::optional<torch::Tensor>& bias  // [OC]
 ) {
   CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
   // Checks for conformality

@@ -442,8 +442,8 @@ void int8_scaled_mm_azp(torch::Tensor& c,  // [M, OC], row-major
     const torch::Tensor& a_scales,  // [1] or [M]
     const torch::Tensor& b_scales,  // [1] or [OC]
     const torch::Tensor& azp_adj,   // [OC]
-    const c10::optional<torch::Tensor>& azp,  // [1] or [M]
-    const c10::optional<torch::Tensor>& bias  // [OC]
+    const std::optional<torch::Tensor>& azp,  // [1] or [M]
+    const std::optional<torch::Tensor>& bias  // [OC]
 ) {
   CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp)
   // Checks for conformality

@@ -561,7 +561,7 @@ void int8_scaled_mm_azp(torch::Tensor& c,  // [M, OC], row-major
 void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
     const torch::Tensor& input,  // [..., hidden_size]
     const torch::Tensor& scale,
-    c10::optional<torch::Tensor> const& azp) {
+    std::optional<torch::Tensor> const& azp) {
   CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());

@@ -590,7 +590,7 @@ void dynamic_scaled_int8_quant(
     torch::Tensor& out,          // [..., hidden_size]
     const torch::Tensor& input,  // [..., hidden_size]
     torch::Tensor& scale,        // [..., 1]
-    c10::optional<torch::Tensor> const& azp) {
+    std::optional<torch::Tensor> const& azp) {
   CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());

@@ -9,14 +9,14 @@ std::string init_cpu_threads_env(const std::string& cpu_ids);
 void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
     const torch::Tensor& b, const torch::Tensor& a_scales,
     const torch::Tensor& b_scales,
-    const c10::optional<torch::Tensor>& bias);
+    const std::optional<torch::Tensor>& bias);

 void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
     const torch::Tensor& b, const torch::Tensor& a_scales,
     const torch::Tensor& b_scales,
     const torch::Tensor& azp_adj,
-    const c10::optional<torch::Tensor>& azp,
-    const c10::optional<torch::Tensor>& bias);
+    const std::optional<torch::Tensor>& azp,
+    const std::optional<torch::Tensor>& bias);

 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // vLLM custom ops
@@ -68,7 +68,7 @@ struct ScaledEpilogueBase {
   // This overload handles the case where there might not be a tensor, in which
   // case a nullptr is passed and a constant (0) is used.
   template <typename Descriptor, typename T>
-  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
+  static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
     static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
     using Arguments = typename Descriptor::Arguments;
     auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;

@@ -223,7 +223,7 @@ struct ScaledEpilogueBiasAzp
   static ArgumentType prepare_args(torch::Tensor const& a_scales,
       torch::Tensor const& b_scales,
       torch::Tensor const& azp_adj,
-      c10::optional<torch::Tensor> const& bias) {
+      std::optional<torch::Tensor> const& bias) {
     auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
     auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
     auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

@@ -301,7 +301,7 @@ struct ScaledEpilogueBiasAzpToken
       torch::Tensor const& b_scales,
       torch::Tensor const& azp_adj,
       torch::Tensor const& azp,
-      c10::optional<torch::Tensor> const& bias) {
+      std::optional<torch::Tensor> const& bias) {
     auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
     auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
     auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
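The args_from_tensor overloads above map an absent tensor to a null pointer, which the epilogue then replaces with a constant zero. A standalone sketch of that idiom, simplified from the code above (the real overload builds CUTLASS epilogue argument structs):

    // Sketch: reduce an optional tensor to "data pointer or nullptr".
    #include <optional>
    #include <torch/torch.h>

    template <typename T>
    T* data_ptr_or_null(const std::optional<torch::Tensor>& tensor) {
      // Empty optional -> nullptr; present tensor -> its raw data pointer.
      return tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
    }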
@@ -67,7 +67,7 @@ struct ScaledEpilogueBase {
   // This overload handles the case where there might not be a tensor, in which
   // case a nullptr is passed and a constant (0) is used.
   template <typename Descriptor, typename T>
-  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
+  static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
     using Arguments = typename Descriptor::Arguments;
     auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
     static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||

@@ -223,7 +223,7 @@ struct ScaledEpilogueBiasAzp
   static ArgumentType prepare_args(torch::Tensor const& a_scales,
       torch::Tensor const& b_scales,
       torch::Tensor const& azp_adj,
-      c10::optional<torch::Tensor> const& bias) {
+      std::optional<torch::Tensor> const& bias) {
     auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
     auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
     auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

@@ -299,7 +299,7 @@ struct ScaledEpilogueBiasAzpToken
       torch::Tensor const& b_scales,
       torch::Tensor const& azp_adj,
       torch::Tensor const& azp,
-      c10::optional<torch::Tensor> const& bias) {
+      std::optional<torch::Tensor> const& bias) {
     auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
     auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
     auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

@@ -97,7 +97,7 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,

 template <typename Stride>
 static inline auto maybe_make_cute_layout(
-    c10::optional<torch::Tensor> const& tensor,
+    std::optional<torch::Tensor> const& tensor,
     std::string_view name = "tensor") {
   using Layout = decltype(make_cute_layout<Stride>(*tensor));

@@ -53,12 +53,12 @@ void set_conv_params_fwd(ConvParamsBase &params,
     const at::Tensor x,
     const at::Tensor weight,
     const at::Tensor out,
-    const c10::optional<at::Tensor>& bias,
+    const std::optional<at::Tensor>& bias,
     bool silu_activation,
     int64_t pad_slot_id,
-    const c10::optional<at::Tensor>& query_start_loc = std::nullopt,
-    const c10::optional<at::Tensor>& cache_indices = std::nullopt,
-    const c10::optional<at::Tensor>& has_initial_state = std::nullopt) {
+    const std::optional<at::Tensor>& query_start_loc = std::nullopt,
+    const std::optional<at::Tensor>& cache_indices = std::nullopt,
+    const std::optional<at::Tensor>& has_initial_state = std::nullopt) {

   // Reset the parameters
   memset(&params, 0, sizeof(params));

@@ -93,11 +93,11 @@ void set_conv_params_fwd(ConvParamsBase &params,


 void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
-    const c10::optional<at::Tensor> &bias_,
-    const c10::optional<at::Tensor> &conv_states,
-    const c10::optional<at::Tensor> &query_start_loc,
-    const c10::optional<at::Tensor> &cache_indices,
-    const c10::optional<at::Tensor> &has_initial_state,
+    const std::optional<at::Tensor> &bias_,
+    const std::optional<at::Tensor> &conv_states,
+    const std::optional<at::Tensor> &query_start_loc,
+    const std::optional<at::Tensor> &cache_indices,
+    const std::optional<at::Tensor> &has_initial_state,
     bool silu_activation,
     // used to identify padding entries if cache_indices provided
     // in case of padding, the kernel will return early

@@ -194,10 +194,10 @@ void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
 void causal_conv1d_update(const at::Tensor &x,
     const at::Tensor &conv_state,
     const at::Tensor &weight,
-    const c10::optional<at::Tensor> &bias_,
+    const std::optional<at::Tensor> &bias_,
     bool silu_activation,
-    const c10::optional<at::Tensor> &cache_seqlens_,
-    const c10::optional<at::Tensor> &conv_state_indices_,
+    const std::optional<at::Tensor> &cache_seqlens_,
+    const std::optional<at::Tensor> &conv_state_indices_,
     // used to identify padding entries if cache_indices provided
     // in case of padding, the kernel will return early
     int64_t pad_slot_id) {

@@ -402,14 +402,14 @@ void set_ssm_params_fwd(SSMParamsBase &params,
     const torch::Tensor out,
     const torch::Tensor z,
     const torch::Tensor out_z,
-    const c10::optional<at::Tensor>& D,
-    const c10::optional<at::Tensor>& delta_bias,
+    const std::optional<at::Tensor>& D,
+    const std::optional<at::Tensor>& delta_bias,
     const torch::Tensor ssm_states,
     bool has_z,
     bool delta_softplus,
-    const c10::optional<at::Tensor>& query_start_loc,
-    const c10::optional<at::Tensor>& cache_indices,
-    const c10::optional<at::Tensor>& has_initial_state,
+    const std::optional<at::Tensor>& query_start_loc,
+    const std::optional<at::Tensor>& cache_indices,
+    const std::optional<at::Tensor>& has_initial_state,
     bool varlen,
     int64_t pad_slot_id) {

@@ -504,13 +504,13 @@ void set_ssm_params_fwd(SSMParamsBase &params,

 void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
     const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C,
-    const c10::optional<torch::Tensor> &D_,
-    const c10::optional<torch::Tensor> &z_,
-    const c10::optional<torch::Tensor> &delta_bias_,
+    const std::optional<torch::Tensor> &D_,
+    const std::optional<torch::Tensor> &z_,
+    const std::optional<torch::Tensor> &delta_bias_,
     bool delta_softplus,
-    const c10::optional<torch::Tensor> &query_start_loc,
-    const c10::optional<torch::Tensor> &cache_indices,
-    const c10::optional<torch::Tensor> &has_initial_state,
+    const std::optional<torch::Tensor> &query_start_loc,
+    const std::optional<torch::Tensor> &cache_indices,
+    const std::optional<torch::Tensor> &has_initial_state,
     const torch::Tensor &ssm_states,
     // used to identify padding entries if cache_indices provided
     // in case of padding, the kernel will return early
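set_conv_params_fwd above also shows optional parameters with defaults (= std::nullopt). Binding std::nullopt to a const std::optional<at::Tensor>& creates a temporary empty optional whose lifetime covers the call, so the defaulted form is safe. A short sketch with a hypothetical function name:

    // Hypothetical function; the defaulting pattern mirrors set_conv_params_fwd.
    #include <optional>
    #include <ATen/ATen.h>

    void check_indices(const at::Tensor& x,
                       const std::optional<at::Tensor>& indices = std::nullopt) {
      if (indices) {
        // A present optional dereferences like a pointer.
        TORCH_CHECK(indices->dim() == 1);
      }
    }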
csrc/ops.h (46 changed lines)
@@ -33,7 +33,7 @@ void paged_attention_v1(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
-    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, double k_scale, double v_scale,
     const int64_t tp_rank, const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,

@@ -44,7 +44,7 @@ void paged_attention_v2(
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
-    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, double k_scale, double v_scale,
     const int64_t tp_rank, const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,

@@ -153,15 +153,15 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
     torch::Tensor const& b, torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
-    c10::optional<torch::Tensor> const& bias);
+    std::optional<torch::Tensor> const& bias);

 void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
     torch::Tensor const& b,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
     torch::Tensor const& azp_adj,
-    c10::optional<torch::Tensor> const& azp,
-    c10::optional<torch::Tensor> const& bias);
+    std::optional<torch::Tensor> const& azp,
+    std::optional<torch::Tensor> const& bias);

 bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);

@@ -169,7 +169,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
     torch::Tensor const& b, torch::Tensor const& e,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
-    c10::optional<torch::Tensor> const& bias);
+    std::optional<torch::Tensor> const& bias);

 bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
     torch::Tensor& e, torch::Tensor const& a);

@@ -177,11 +177,11 @@ bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,

 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
     torch::Tensor const& scale,
-    c10::optional<torch::Tensor> const& azp);
+    std::optional<torch::Tensor> const& azp);

 void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
     torch::Tensor& scales,
-    c10::optional<torch::Tensor> const& azp);
+    std::optional<torch::Tensor> const& azp);

 torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
     torch::Tensor b_gptq_qzeros,

@@ -198,34 +198,34 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,

 void dynamic_per_token_scaled_fp8_quant(
     torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
-    c10::optional<torch::Tensor> const& scale_ub);
+    std::optional<torch::Tensor> const& scale_ub);

 void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
     const torch::Tensor& A, const torch::Tensor& B,
     const torch::Tensor& C,
-    const c10::optional<torch::Tensor>& D_,
-    const c10::optional<torch::Tensor>& z_,
-    const c10::optional<torch::Tensor>& delta_bias_,
+    const std::optional<torch::Tensor>& D_,
+    const std::optional<torch::Tensor>& z_,
+    const std::optional<torch::Tensor>& delta_bias_,
     bool delta_softplus,
-    const c10::optional<torch::Tensor>& query_start_loc,
-    const c10::optional<torch::Tensor>& cache_indices,
-    const c10::optional<torch::Tensor>& has_initial_state,
+    const std::optional<torch::Tensor>& query_start_loc,
+    const std::optional<torch::Tensor>& cache_indices,
+    const std::optional<torch::Tensor>& has_initial_state,
     const torch::Tensor& ssm_states, int64_t pad_slot_id);

 void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
     const at::Tensor& weight,
-    const c10::optional<at::Tensor>& bias_,
+    const std::optional<at::Tensor>& bias_,
     bool silu_activation,
-    const c10::optional<at::Tensor>& cache_seqlens_,
-    const c10::optional<at::Tensor>& conv_state_indices_,
+    const std::optional<at::Tensor>& cache_seqlens_,
+    const std::optional<at::Tensor>& conv_state_indices_,
     int64_t pad_slot_id);

 void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
-    const c10::optional<at::Tensor>& bias_,
-    const c10::optional<at::Tensor>& conv_states,
-    const c10::optional<at::Tensor>& query_start_loc,
-    const c10::optional<at::Tensor>& cache_indices,
-    const c10::optional<at::Tensor>& has_initial_state,
+    const std::optional<at::Tensor>& bias_,
+    const std::optional<at::Tensor>& conv_states,
+    const std::optional<at::Tensor>& query_start_loc,
+    const std::optional<at::Tensor>& cache_indices,
+    const std::optional<at::Tensor>& has_initial_state,
     bool silu_activation, int64_t pad_slot_id);

 #ifndef USE_ROCM
@@ -226,7 +226,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
 void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
     torch::Tensor const& input,  // [..., hidden_size]
     torch::Tensor const& scale,
-    c10::optional<torch::Tensor> const& azp) {
+    std::optional<torch::Tensor> const& azp) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(scale.numel() == 1);

@@ -257,7 +257,7 @@ void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
 void dynamic_scaled_int8_quant(
     torch::Tensor& out,          // [..., hidden_size]
     torch::Tensor const& input,  // [..., hidden_size]
-    torch::Tensor& scales, c10::optional<torch::Tensor> const& azp) {
+    torch::Tensor& scales, std::optional<torch::Tensor> const& azp) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(scales.is_contiguous());
@@ -39,7 +39,7 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
     torch::Tensor const& b,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
-    c10::optional<torch::Tensor> const& bias) {
+    std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
   if (bias) {

@@ -58,8 +58,8 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
     torch::Tensor const& azp_adj,
-    c10::optional<torch::Tensor> const& azp,
-    c10::optional<torch::Tensor> const& bias) {
+    std::optional<torch::Tensor> const& azp,
+    std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

@@ -94,7 +94,7 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
     torch::Tensor const& b,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
-    c10::optional<torch::Tensor> const& bias) {
+    std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
   if (bias) {

@@ -113,8 +113,8 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
     torch::Tensor const& azp_adj,
-    c10::optional<torch::Tensor> const& azp,
-    c10::optional<torch::Tensor> const& bias) {
+    std::optional<torch::Tensor> const& azp,
+    std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

@@ -165,7 +165,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
     torch::Tensor const& b,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
-    c10::optional<torch::Tensor> const& bias) {
+    std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
   if (bias) {

@@ -184,8 +184,8 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
     torch::Tensor const& azp_adj,
-    c10::optional<torch::Tensor> const& azp,
-    c10::optional<torch::Tensor> const& bias) {
+    std::optional<torch::Tensor> const& azp,
+    std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

@@ -51,7 +51,7 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
     torch::Tensor const& b,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
-    c10::optional<torch::Tensor> const& bias) {
+    std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
   if (bias) {

@@ -70,8 +70,8 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
     torch::Tensor const& azp_adj,
-    c10::optional<torch::Tensor> const& azp,
-    c10::optional<torch::Tensor> const& bias) {
+    std::optional<torch::Tensor> const& azp,
+    std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
@@ -9,26 +9,26 @@ void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
     torch::Tensor const& b,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
-    c10::optional<torch::Tensor> const& bias);
+    std::optional<torch::Tensor> const& bias);

 void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
     torch::Tensor const& b,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
-    c10::optional<torch::Tensor> const& bias);
+    std::optional<torch::Tensor> const& bias);

 void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
     torch::Tensor const& b,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
-    c10::optional<torch::Tensor> const& bias);
+    std::optional<torch::Tensor> const& bias);

 #if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
 void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
     torch::Tensor const& b,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
-    c10::optional<torch::Tensor> const& bias);
+    std::optional<torch::Tensor> const& bias);
 #endif

 void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,

@@ -36,24 +36,24 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
     torch::Tensor const& azp_adj,
-    c10::optional<torch::Tensor> const& azp,
-    c10::optional<torch::Tensor> const& bias);
+    std::optional<torch::Tensor> const& azp,
+    std::optional<torch::Tensor> const& bias);

 void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
     torch::Tensor const& b,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
     torch::Tensor const& azp_adj,
-    c10::optional<torch::Tensor> const& azp,
-    c10::optional<torch::Tensor> const& bias);
+    std::optional<torch::Tensor> const& azp,
+    std::optional<torch::Tensor> const& bias);

 void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
     torch::Tensor const& b,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
     torch::Tensor const& azp_adj,
-    c10::optional<torch::Tensor> const& azp,
-    c10::optional<torch::Tensor> const& bias);
+    std::optional<torch::Tensor> const& azp,
+    std::optional<torch::Tensor> const& bias);

 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
 void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,

@@ -61,8 +61,8 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
     torch::Tensor const& azp_adj,
-    c10::optional<torch::Tensor> const& azp,
-    c10::optional<torch::Tensor> const& bias);
+    std::optional<torch::Tensor> const& azp,
+    std::optional<torch::Tensor> const& bias);
 #endif

 bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {

@@ -84,7 +84,7 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
 void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
     torch::Tensor const& b, torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
-    c10::optional<torch::Tensor> const& bias) {
+    std::optional<torch::Tensor> const& bias) {
   // Checks for conformality
   TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
   TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&

@@ -148,8 +148,8 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
     torch::Tensor const& azp_adj,
-    c10::optional<torch::Tensor> const& azp,
-    c10::optional<torch::Tensor> const& bias) {
+    std::optional<torch::Tensor> const& azp,
+    std::optional<torch::Tensor> const& bias) {
   // Checks for conformality
   TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
   TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
@@ -63,7 +63,7 @@ torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {


 static inline std::optional<at::ScalarType> maybe_scalartype(
-    c10::optional<at::Tensor> const& t) {
+    std::optional<at::Tensor> const& t) {
   if (!t) {
     return std::nullopt;
   } else {
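The hunk above cuts off after the else branch of maybe_scalartype. Given the declared return type, a plausible completion maps a present tensor to its dtype; the returned expression is an assumption, since the diff does not show it:

    // Assumed completion of the truncated body (the else branch is a guess).
    static inline std::optional<at::ScalarType> maybe_scalartype(
        std::optional<at::Tensor> const& t) {
      if (!t) {
        return std::nullopt;      // no tensor -> no dtype
      } else {
        return t->scalar_type();  // present tensor -> its dtype
      }
    }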
@@ -183,11 +183,11 @@ struct MacheteKernelTemplate {
     torch::Tensor const& A,  // MxK matrix
     torch::Tensor const& B,  // KxN prepacked matrix
     torch::Tensor& D,        // MxN matrix
-    c10::optional<torch::Tensor> const& maybe_g_scales,  // scale_KxN matrix
-    c10::optional<torch::Tensor> const& maybe_g_zeros,   // scale_KxN matrix
-    c10::optional<int64_t> maybe_group_size,
-    c10::optional<torch::Tensor> const& maybe_ch_scales,   // len N vector
-    c10::optional<torch::Tensor> const& maybe_tok_scales)  // len M vector
+    std::optional<torch::Tensor> const& maybe_g_scales,  // scale_KxN matrix
+    std::optional<torch::Tensor> const& maybe_g_zeros,   // scale_KxN matrix
+    std::optional<int64_t> maybe_group_size,
+    std::optional<torch::Tensor> const& maybe_ch_scales,   // len N vector
+    std::optional<torch::Tensor> const& maybe_tok_scales)  // len M vector
 {
   static_assert(!with_group_zeropoints || with_group_scales);

@@ -13,23 +13,23 @@ struct MMArgs {
   torch::Tensor const& A;
   torch::Tensor const& B;
   vllm::ScalarType const& b_type;
-  c10::optional<at::ScalarType> const& maybe_out_type;
-  c10::optional<torch::Tensor> const& maybe_group_scales;
-  c10::optional<torch::Tensor> const& maybe_group_zeros;
-  c10::optional<int64_t> maybe_group_size;
-  c10::optional<torch::Tensor> const& maybe_channel_scales;
-  c10::optional<torch::Tensor> const& maybe_token_scales;
-  c10::optional<std::string> maybe_schedule;
+  std::optional<at::ScalarType> const& maybe_out_type;
+  std::optional<torch::Tensor> const& maybe_group_scales;
+  std::optional<torch::Tensor> const& maybe_group_zeros;
+  std::optional<int64_t> maybe_group_size;
+  std::optional<torch::Tensor> const& maybe_channel_scales;
+  std::optional<torch::Tensor> const& maybe_token_scales;
+  std::optional<std::string> maybe_schedule;
 };

 struct SupportedSchedulesArgs {
   at::ScalarType a_type;
   vllm::ScalarType b_type;
-  c10::optional<at::ScalarType> maybe_group_scales_type;
-  c10::optional<at::ScalarType> maybe_group_zeros_type;
-  c10::optional<at::ScalarType> maybe_channel_scales_type;
-  c10::optional<at::ScalarType> maybe_token_scales_type;
-  c10::optional<at::ScalarType> maybe_out_type;
+  std::optional<at::ScalarType> maybe_group_scales_type;
+  std::optional<at::ScalarType> maybe_group_zeros_type;
+  std::optional<at::ScalarType> maybe_channel_scales_type;
+  std::optional<at::ScalarType> maybe_token_scales_type;
+  std::optional<at::ScalarType> maybe_out_type;
 };

 torch::Tensor mm_dispatch(MMArgs args);

@@ -10,7 +10,7 @@ struct PrepackBArgs {
   torch::Tensor const& B;
   at::ScalarType a_type;
   vllm::ScalarType b_type;
-  c10::optional<at::ScalarType> maybe_group_scales_type;
+  std::optional<at::ScalarType> maybe_group_scales_type;
 };

 template <typename PrepackedLayoutB>

@@ -10,11 +10,11 @@ using namespace vllm;

 std::vector<std::string> supported_schedules(
     at::ScalarType a_type, int64_t b_type_id,
-    c10::optional<at::ScalarType> maybe_group_scales_type,
-    c10::optional<at::ScalarType> maybe_group_zeros_type,
-    c10::optional<at::ScalarType> maybe_channel_scales_type,
-    c10::optional<at::ScalarType> maybe_token_scales_type,
-    c10::optional<at::ScalarType> maybe_out_type) {
+    std::optional<at::ScalarType> maybe_group_scales_type,
+    std::optional<at::ScalarType> maybe_group_zeros_type,
+    std::optional<at::ScalarType> maybe_channel_scales_type,
+    std::optional<at::ScalarType> maybe_token_scales_type,
+    std::optional<at::ScalarType> maybe_out_type) {
   ScalarType const b_type = ScalarType::from_id(b_type_id);
   return supported_schedules_dispatch({
       .a_type = a_type,

@@ -29,13 +29,13 @@ std::vector<std::string> supported_schedules(

 torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
     int64_t b_type_id,
-    c10::optional<at::ScalarType> const& maybe_out_type,
-    c10::optional<torch::Tensor> const& maybe_group_scales,
-    c10::optional<torch::Tensor> const& maybe_group_zeros,
-    c10::optional<int64_t> maybe_group_size,
-    c10::optional<torch::Tensor> const& maybe_channel_scales,
-    c10::optional<torch::Tensor> const& maybe_token_scales,
-    c10::optional<std::string> maybe_schedule) {
+    std::optional<at::ScalarType> const& maybe_out_type,
+    std::optional<torch::Tensor> const& maybe_group_scales,
+    std::optional<torch::Tensor> const& maybe_group_zeros,
+    std::optional<int64_t> maybe_group_size,
+    std::optional<torch::Tensor> const& maybe_channel_scales,
+    std::optional<torch::Tensor> const& maybe_token_scales,
+    std::optional<std::string> maybe_schedule) {
   ScalarType const b_type = ScalarType::from_id(b_type_id);
   return mm_dispatch({.A = A,
       .B = B,

@@ -51,7 +51,7 @@ torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,

 torch::Tensor prepack_B(
     torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id,
-    c10::optional<at::ScalarType> const& maybe_group_scales_type) {
+    std::optional<at::ScalarType> const& maybe_group_scales_type) {
   ScalarType const b_type = ScalarType::from_id(b_type_id);
   return prepack_B_dispatch(
       {.B = B,
@@ -928,7 +928,7 @@ void paged_attention_custom_launcher(
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, const int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& context_lens,
-    int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    int max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
     float k_scale, float v_scale) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);

@@ -1086,7 +1086,7 @@ void paged_attention(
     torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
     torch::Tensor& context_lens,  // [num_seqs]
     int64_t block_size, int64_t max_context_len,
-    const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, double k_scale, double v_scale) {
   const int head_size = query.size(2);
   if (kv_cache_dtype == "auto") {

@@ -9,6 +9,6 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
     double scale, torch::Tensor& block_tables,
     torch::Tensor& context_lens, int64_t block_size,
     int64_t max_context_len,
-    const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, double k_scale,
     double v_scale);

@@ -286,7 +286,7 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
     torch::Tensor const& bt_meta,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
-    c10::optional<torch::Tensor> const& bias) {
+    std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
   if (bias) {

@@ -22,7 +22,7 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
     torch::Tensor const& e,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
-    c10::optional<torch::Tensor> const& bias);
+    std::optional<torch::Tensor> const& bias);
 #endif

 void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,

@@ -30,7 +30,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
     torch::Tensor const& bt_meta,
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
-    c10::optional<torch::Tensor> const& bias) {
+    std::optional<torch::Tensor> const& bias) {
   // Checks for conformality
   TORCH_CHECK(a.dim() == 2 && bt_nzs.dim() == 2 && c.dim() == 2);
   TORCH_CHECK(c.size(1) == bt_nzs.size(0) && bt_nzs.size(1) * 2 == a.size(1) &&