[MISC] Replace c10::optional with std::optional (#11730)
Signed-off-by: Lu Fang <lufang@fb.com>
This commit is contained in:
parent
47831430cc
commit
4068f4b5b5
@ -53,7 +53,7 @@ void paged_attention_v1_launcher(
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
|
||||
const std::optional<torch::Tensor>& alibi_slopes, float k_scale,
|
||||
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
|
||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
||||
const int blocksparse_head_sliding_step) {
|
||||
@ -176,7 +176,7 @@ void paged_attention_v1(
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& seq_lens, // [num_seqs]
|
||||
int64_t block_size, int64_t max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
|
@ -54,7 +54,7 @@ void paged_attention_v2_launcher(
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
|
||||
const std::optional<torch::Tensor>& alibi_slopes, float k_scale,
|
||||
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
|
||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
||||
const int blocksparse_head_sliding_step) {
|
||||
@ -187,7 +187,7 @@ void paged_attention_v2(
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& seq_lens, // [num_seqs]
|
||||
int64_t block_size, int64_t max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
|
@ -386,7 +386,7 @@ void paged_attention_v1_impl_launcher(
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||
const std::optional<torch::Tensor>& alibi_slopes) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -459,7 +459,7 @@ void paged_attention_v1(
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
@ -702,7 +702,7 @@ void paged_attention_v2_impl_launcher(
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
|
||||
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||
int max_seq_len, const std::optional<torch::Tensor>& alibi_slopes) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -781,7 +781,7 @@ void paged_attention_v2(
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
|
@ -359,7 +359,7 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
|
||||
const torch::Tensor& b, // [IC, OC], column-major
|
||||
const torch::Tensor& a_scales, // [1] or [M]
|
||||
const torch::Tensor& b_scales, // [1] or [OC]
|
||||
const c10::optional<torch::Tensor>& bias // [OC]
|
||||
const std::optional<torch::Tensor>& bias // [OC]
|
||||
) {
|
||||
CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
|
||||
// Checks for conformality
|
||||
@ -442,8 +442,8 @@ void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major
|
||||
const torch::Tensor& a_scales, // [1] or [M]
|
||||
const torch::Tensor& b_scales, // [1] or [OC]
|
||||
const torch::Tensor& azp_adj, // [OC]
|
||||
const c10::optional<torch::Tensor>& azp, // [1] or [M]
|
||||
const c10::optional<torch::Tensor>& bias // [OC]
|
||||
const std::optional<torch::Tensor>& azp, // [1] or [M]
|
||||
const std::optional<torch::Tensor>& bias // [OC]
|
||||
) {
|
||||
CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp)
|
||||
// Checks for conformality
|
||||
@ -561,7 +561,7 @@ void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major
|
||||
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
const torch::Tensor& input, // [..., hidden_size]
|
||||
const torch::Tensor& scale,
|
||||
c10::optional<torch::Tensor> const& azp) {
|
||||
std::optional<torch::Tensor> const& azp) {
|
||||
CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
@ -590,7 +590,7 @@ void dynamic_scaled_int8_quant(
|
||||
torch::Tensor& out, // [..., hidden_size]
|
||||
const torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& scale, // [..., 1]
|
||||
c10::optional<torch::Tensor> const& azp) {
|
||||
std::optional<torch::Tensor> const& azp) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
|
@ -9,14 +9,14 @@ std::string init_cpu_threads_env(const std::string& cpu_ids);
|
||||
void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
|
||||
const torch::Tensor& b, const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales,
|
||||
const c10::optional<torch::Tensor>& bias);
|
||||
const std::optional<torch::Tensor>& bias);
|
||||
|
||||
void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
|
||||
const torch::Tensor& b, const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales,
|
||||
const torch::Tensor& azp_adj,
|
||||
const c10::optional<torch::Tensor>& azp,
|
||||
const c10::optional<torch::Tensor>& bias);
|
||||
const std::optional<torch::Tensor>& azp,
|
||||
const std::optional<torch::Tensor>& bias);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
@ -68,7 +68,7 @@ struct ScaledEpilogueBase {
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
|
||||
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
|
||||
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
@ -223,7 +223,7 @@ struct ScaledEpilogueBiasAzp
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@ -301,7 +301,7 @@ struct ScaledEpilogueBiasAzpToken
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
torch::Tensor const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
@ -67,7 +67,7 @@ struct ScaledEpilogueBase {
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
|
||||
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
|
||||
@ -223,7 +223,7 @@ struct ScaledEpilogueBiasAzp
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@ -299,7 +299,7 @@ struct ScaledEpilogueBiasAzpToken
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
torch::Tensor const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
@ -97,7 +97,7 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
|
||||
|
||||
template <typename Stride>
|
||||
static inline auto maybe_make_cute_layout(
|
||||
c10::optional<torch::Tensor> const& tensor,
|
||||
std::optional<torch::Tensor> const& tensor,
|
||||
std::string_view name = "tensor") {
|
||||
using Layout = decltype(make_cute_layout<Stride>(*tensor));
|
||||
|
||||
|
@ -53,12 +53,12 @@ void set_conv_params_fwd(ConvParamsBase ¶ms,
|
||||
const at::Tensor x,
|
||||
const at::Tensor weight,
|
||||
const at::Tensor out,
|
||||
const c10::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
bool silu_activation,
|
||||
int64_t pad_slot_id,
|
||||
const c10::optional<at::Tensor>& query_start_loc = std::nullopt,
|
||||
const c10::optional<at::Tensor>& cache_indices = std::nullopt,
|
||||
const c10::optional<at::Tensor>& has_initial_state = std::nullopt) {
|
||||
const std::optional<at::Tensor>& query_start_loc = std::nullopt,
|
||||
const std::optional<at::Tensor>& cache_indices = std::nullopt,
|
||||
const std::optional<at::Tensor>& has_initial_state = std::nullopt) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
@ -93,11 +93,11 @@ void set_conv_params_fwd(ConvParamsBase ¶ms,
|
||||
|
||||
|
||||
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
const c10::optional<at::Tensor> &bias_,
|
||||
const c10::optional<at::Tensor> &conv_states,
|
||||
const c10::optional<at::Tensor> &query_start_loc,
|
||||
const c10::optional<at::Tensor> &cache_indices,
|
||||
const c10::optional<at::Tensor> &has_initial_state,
|
||||
const std::optional<at::Tensor> &bias_,
|
||||
const std::optional<at::Tensor> &conv_states,
|
||||
const std::optional<at::Tensor> &query_start_loc,
|
||||
const std::optional<at::Tensor> &cache_indices,
|
||||
const std::optional<at::Tensor> &has_initial_state,
|
||||
bool silu_activation,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
@ -194,10 +194,10 @@ void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
void causal_conv1d_update(const at::Tensor &x,
|
||||
const at::Tensor &conv_state,
|
||||
const at::Tensor &weight,
|
||||
const c10::optional<at::Tensor> &bias_,
|
||||
const std::optional<at::Tensor> &bias_,
|
||||
bool silu_activation,
|
||||
const c10::optional<at::Tensor> &cache_seqlens_,
|
||||
const c10::optional<at::Tensor> &conv_state_indices_,
|
||||
const std::optional<at::Tensor> &cache_seqlens_,
|
||||
const std::optional<at::Tensor> &conv_state_indices_,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
int64_t pad_slot_id) {
|
||||
|
@ -402,14 +402,14 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
const torch::Tensor out,
|
||||
const torch::Tensor z,
|
||||
const torch::Tensor out_z,
|
||||
const c10::optional<at::Tensor>& D,
|
||||
const c10::optional<at::Tensor>& delta_bias,
|
||||
const std::optional<at::Tensor>& D,
|
||||
const std::optional<at::Tensor>& delta_bias,
|
||||
const torch::Tensor ssm_states,
|
||||
bool has_z,
|
||||
bool delta_softplus,
|
||||
const c10::optional<at::Tensor>& query_start_loc,
|
||||
const c10::optional<at::Tensor>& cache_indices,
|
||||
const c10::optional<at::Tensor>& has_initial_state,
|
||||
const std::optional<at::Tensor>& query_start_loc,
|
||||
const std::optional<at::Tensor>& cache_indices,
|
||||
const std::optional<at::Tensor>& has_initial_state,
|
||||
bool varlen,
|
||||
int64_t pad_slot_id) {
|
||||
|
||||
@ -504,13 +504,13 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
|
||||
void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C,
|
||||
const c10::optional<torch::Tensor> &D_,
|
||||
const c10::optional<torch::Tensor> &z_,
|
||||
const c10::optional<torch::Tensor> &delta_bias_,
|
||||
const std::optional<torch::Tensor> &D_,
|
||||
const std::optional<torch::Tensor> &z_,
|
||||
const std::optional<torch::Tensor> &delta_bias_,
|
||||
bool delta_softplus,
|
||||
const c10::optional<torch::Tensor> &query_start_loc,
|
||||
const c10::optional<torch::Tensor> &cache_indices,
|
||||
const c10::optional<torch::Tensor> &has_initial_state,
|
||||
const std::optional<torch::Tensor> &query_start_loc,
|
||||
const std::optional<torch::Tensor> &cache_indices,
|
||||
const std::optional<torch::Tensor> &has_initial_state,
|
||||
const torch::Tensor &ssm_states,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
|
46
csrc/ops.h
46
csrc/ops.h
@ -33,7 +33,7 @@ void paged_attention_v1(
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
@ -44,7 +44,7 @@ void paged_attention_v2(
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
@ -153,15 +153,15 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);
|
||||
|
||||
@ -169,7 +169,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& e,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
|
||||
torch::Tensor& e, torch::Tensor const& a);
|
||||
@ -177,11 +177,11 @@ bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
|
||||
|
||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor const& scale,
|
||||
c10::optional<torch::Tensor> const& azp);
|
||||
std::optional<torch::Tensor> const& azp);
|
||||
|
||||
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor& scales,
|
||||
c10::optional<torch::Tensor> const& azp);
|
||||
std::optional<torch::Tensor> const& azp);
|
||||
|
||||
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
||||
torch::Tensor b_gptq_qzeros,
|
||||
@ -198,34 +198,34 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
|
||||
void dynamic_per_token_scaled_fp8_quant(
|
||||
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
|
||||
c10::optional<torch::Tensor> const& scale_ub);
|
||||
std::optional<torch::Tensor> const& scale_ub);
|
||||
|
||||
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
|
||||
const torch::Tensor& A, const torch::Tensor& B,
|
||||
const torch::Tensor& C,
|
||||
const c10::optional<torch::Tensor>& D_,
|
||||
const c10::optional<torch::Tensor>& z_,
|
||||
const c10::optional<torch::Tensor>& delta_bias_,
|
||||
const std::optional<torch::Tensor>& D_,
|
||||
const std::optional<torch::Tensor>& z_,
|
||||
const std::optional<torch::Tensor>& delta_bias_,
|
||||
bool delta_softplus,
|
||||
const c10::optional<torch::Tensor>& query_start_loc,
|
||||
const c10::optional<torch::Tensor>& cache_indices,
|
||||
const c10::optional<torch::Tensor>& has_initial_state,
|
||||
const std::optional<torch::Tensor>& query_start_loc,
|
||||
const std::optional<torch::Tensor>& cache_indices,
|
||||
const std::optional<torch::Tensor>& has_initial_state,
|
||||
const torch::Tensor& ssm_states, int64_t pad_slot_id);
|
||||
|
||||
void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
|
||||
const at::Tensor& weight,
|
||||
const c10::optional<at::Tensor>& bias_,
|
||||
const std::optional<at::Tensor>& bias_,
|
||||
bool silu_activation,
|
||||
const c10::optional<at::Tensor>& cache_seqlens_,
|
||||
const c10::optional<at::Tensor>& conv_state_indices_,
|
||||
const std::optional<at::Tensor>& cache_seqlens_,
|
||||
const std::optional<at::Tensor>& conv_state_indices_,
|
||||
int64_t pad_slot_id);
|
||||
|
||||
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
|
||||
const c10::optional<at::Tensor>& bias_,
|
||||
const c10::optional<at::Tensor>& conv_states,
|
||||
const c10::optional<at::Tensor>& query_start_loc,
|
||||
const c10::optional<at::Tensor>& cache_indices,
|
||||
const c10::optional<at::Tensor>& has_initial_state,
|
||||
const std::optional<at::Tensor>& bias_,
|
||||
const std::optional<at::Tensor>& conv_states,
|
||||
const std::optional<at::Tensor>& query_start_loc,
|
||||
const std::optional<at::Tensor>& cache_indices,
|
||||
const std::optional<at::Tensor>& has_initial_state,
|
||||
bool silu_activation, int64_t pad_slot_id);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
@ -226,7 +226,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
|
||||
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor const& input, // [..., hidden_size]
|
||||
torch::Tensor const& scale,
|
||||
c10::optional<torch::Tensor> const& azp) {
|
||||
std::optional<torch::Tensor> const& azp) {
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(scale.numel() == 1);
|
||||
@ -257,7 +257,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
void dynamic_scaled_int8_quant(
|
||||
torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor const& input, // [..., hidden_size]
|
||||
torch::Tensor& scales, c10::optional<torch::Tensor> const& azp) {
|
||||
torch::Tensor& scales, std::optional<torch::Tensor> const& azp) {
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(scales.is_contiguous());
|
||||
|
@ -39,7 +39,7 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
if (bias) {
|
||||
@ -58,8 +58,8 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
@ -94,7 +94,7 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
if (bias) {
|
||||
@ -113,8 +113,8 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
@ -165,7 +165,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
if (bias) {
|
||||
@ -184,8 +184,8 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
|
@ -51,7 +51,7 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
if (bias) {
|
||||
@ -70,8 +70,8 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
|
@ -9,26 +9,26 @@ void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
@ -36,24 +36,24 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
@ -61,8 +61,8 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
|
||||
@ -84,7 +84,7 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
|
||||
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
@ -148,8 +148,8 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
|
@ -63,7 +63,7 @@ torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {
|
||||
|
||||
|
||||
static inline std::optional<at::ScalarType> maybe_scalartype(
|
||||
c10::optional<at::Tensor> const& t) {
|
||||
std::optional<at::Tensor> const& t) {
|
||||
if (!t) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
|
@ -183,11 +183,11 @@ struct MacheteKernelTemplate {
|
||||
torch::Tensor const& A, // MxK matrix
|
||||
torch::Tensor const& B, // KxN prepacked matrix
|
||||
torch::Tensor& D, // MxN matrix
|
||||
c10::optional<torch::Tensor> const& maybe_g_scales, // scale_KxN matrix
|
||||
c10::optional<torch::Tensor> const& maybe_g_zeros, // scale_KxN matrix
|
||||
c10::optional<int64_t> maybe_group_size,
|
||||
c10::optional<torch::Tensor> const& maybe_ch_scales, // len N vector
|
||||
c10::optional<torch::Tensor> const& maybe_tok_scales) // len M vector
|
||||
std::optional<torch::Tensor> const& maybe_g_scales, // scale_KxN matrix
|
||||
std::optional<torch::Tensor> const& maybe_g_zeros, // scale_KxN matrix
|
||||
std::optional<int64_t> maybe_group_size,
|
||||
std::optional<torch::Tensor> const& maybe_ch_scales, // len N vector
|
||||
std::optional<torch::Tensor> const& maybe_tok_scales) // len M vector
|
||||
{
|
||||
static_assert(!with_group_zeropoints || with_group_scales);
|
||||
|
||||
|
@ -13,23 +13,23 @@ struct MMArgs {
|
||||
torch::Tensor const& A;
|
||||
torch::Tensor const& B;
|
||||
vllm::ScalarType const& b_type;
|
||||
c10::optional<at::ScalarType> const& maybe_out_type;
|
||||
c10::optional<torch::Tensor> const& maybe_group_scales;
|
||||
c10::optional<torch::Tensor> const& maybe_group_zeros;
|
||||
c10::optional<int64_t> maybe_group_size;
|
||||
c10::optional<torch::Tensor> const& maybe_channel_scales;
|
||||
c10::optional<torch::Tensor> const& maybe_token_scales;
|
||||
c10::optional<std::string> maybe_schedule;
|
||||
std::optional<at::ScalarType> const& maybe_out_type;
|
||||
std::optional<torch::Tensor> const& maybe_group_scales;
|
||||
std::optional<torch::Tensor> const& maybe_group_zeros;
|
||||
std::optional<int64_t> maybe_group_size;
|
||||
std::optional<torch::Tensor> const& maybe_channel_scales;
|
||||
std::optional<torch::Tensor> const& maybe_token_scales;
|
||||
std::optional<std::string> maybe_schedule;
|
||||
};
|
||||
|
||||
struct SupportedSchedulesArgs {
|
||||
at::ScalarType a_type;
|
||||
vllm::ScalarType b_type;
|
||||
c10::optional<at::ScalarType> maybe_group_scales_type;
|
||||
c10::optional<at::ScalarType> maybe_group_zeros_type;
|
||||
c10::optional<at::ScalarType> maybe_channel_scales_type;
|
||||
c10::optional<at::ScalarType> maybe_token_scales_type;
|
||||
c10::optional<at::ScalarType> maybe_out_type;
|
||||
std::optional<at::ScalarType> maybe_group_scales_type;
|
||||
std::optional<at::ScalarType> maybe_group_zeros_type;
|
||||
std::optional<at::ScalarType> maybe_channel_scales_type;
|
||||
std::optional<at::ScalarType> maybe_token_scales_type;
|
||||
std::optional<at::ScalarType> maybe_out_type;
|
||||
};
|
||||
|
||||
torch::Tensor mm_dispatch(MMArgs args);
|
||||
|
@ -10,7 +10,7 @@ struct PrepackBArgs {
|
||||
torch::Tensor const& B;
|
||||
at::ScalarType a_type;
|
||||
vllm::ScalarType b_type;
|
||||
c10::optional<at::ScalarType> maybe_group_scales_type;
|
||||
std::optional<at::ScalarType> maybe_group_scales_type;
|
||||
};
|
||||
|
||||
template <typename PrepackedLayoutB>
|
||||
|
@ -10,11 +10,11 @@ using namespace vllm;
|
||||
|
||||
std::vector<std::string> supported_schedules(
|
||||
at::ScalarType a_type, int64_t b_type_id,
|
||||
c10::optional<at::ScalarType> maybe_group_scales_type,
|
||||
c10::optional<at::ScalarType> maybe_group_zeros_type,
|
||||
c10::optional<at::ScalarType> maybe_channel_scales_type,
|
||||
c10::optional<at::ScalarType> maybe_token_scales_type,
|
||||
c10::optional<at::ScalarType> maybe_out_type) {
|
||||
std::optional<at::ScalarType> maybe_group_scales_type,
|
||||
std::optional<at::ScalarType> maybe_group_zeros_type,
|
||||
std::optional<at::ScalarType> maybe_channel_scales_type,
|
||||
std::optional<at::ScalarType> maybe_token_scales_type,
|
||||
std::optional<at::ScalarType> maybe_out_type) {
|
||||
ScalarType const b_type = ScalarType::from_id(b_type_id);
|
||||
return supported_schedules_dispatch({
|
||||
.a_type = a_type,
|
||||
@ -29,13 +29,13 @@ std::vector<std::string> supported_schedules(
|
||||
|
||||
torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
|
||||
int64_t b_type_id,
|
||||
c10::optional<at::ScalarType> const& maybe_out_type,
|
||||
c10::optional<torch::Tensor> const& maybe_group_scales,
|
||||
c10::optional<torch::Tensor> const& maybe_group_zeros,
|
||||
c10::optional<int64_t> maybe_group_size,
|
||||
c10::optional<torch::Tensor> const& maybe_channel_scales,
|
||||
c10::optional<torch::Tensor> const& maybe_token_scales,
|
||||
c10::optional<std::string> maybe_schedule) {
|
||||
std::optional<at::ScalarType> const& maybe_out_type,
|
||||
std::optional<torch::Tensor> const& maybe_group_scales,
|
||||
std::optional<torch::Tensor> const& maybe_group_zeros,
|
||||
std::optional<int64_t> maybe_group_size,
|
||||
std::optional<torch::Tensor> const& maybe_channel_scales,
|
||||
std::optional<torch::Tensor> const& maybe_token_scales,
|
||||
std::optional<std::string> maybe_schedule) {
|
||||
ScalarType const b_type = ScalarType::from_id(b_type_id);
|
||||
return mm_dispatch({.A = A,
|
||||
.B = B,
|
||||
@ -51,7 +51,7 @@ torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
|
||||
|
||||
torch::Tensor prepack_B(
|
||||
torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id,
|
||||
c10::optional<at::ScalarType> const& maybe_group_scales_type) {
|
||||
std::optional<at::ScalarType> const& maybe_group_scales_type) {
|
||||
ScalarType const b_type = ScalarType::from_id(b_type_id);
|
||||
return prepack_B_dispatch(
|
||||
{.B = B,
|
||||
|
@ -928,7 +928,7 @@ void paged_attention_custom_launcher(
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, const int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& context_lens,
|
||||
int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
int max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
float k_scale, float v_scale) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
@ -1086,7 +1086,7 @@ void paged_attention(
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
int64_t block_size, int64_t max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale) {
|
||||
const int head_size = query.size(2);
|
||||
if (kv_cache_dtype == "auto") {
|
||||
|
@ -9,6 +9,6 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
|
||||
double scale, torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens, int64_t block_size,
|
||||
int64_t max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale,
|
||||
double v_scale);
|
||||
|
@ -286,7 +286,7 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& bt_meta,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
if (bias) {
|
||||
|
@ -22,7 +22,7 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& e,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
@ -30,7 +30,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& bt_meta,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dim() == 2 && bt_nzs.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(1) == bt_nzs.size(0) && bt_nzs.size(1) * 2 == a.size(1) &&
|
||||
|
Loading…
x
Reference in New Issue
Block a user