#ifndef MARLIN_NAMESPACE_NAME #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 #endif #include "quantization/gptq_marlin/marlin.cuh" #include "quantization/gptq_marlin/marlin_dtypes.cuh" #include "core/scalar_type.hpp" #define MARLIN_KERNEL_PARAMS \ const int4 *__restrict__ A, const int4 *__restrict__ B, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \ const int *__restrict__ g_idx, \ const int32_t *__restrict__ sorted_token_ids_ptr, \ const int32_t *__restrict__ expert_ids_ptr, \ const int32_t *__restrict__ num_tokens_past_padded_ptr, \ const float *__restrict__ topk_weights_ptr, int top_k, \ bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ int prob_n, int prob_k, int *locks, bool use_atomic_add, \ bool use_fp32_reduce namespace MARLIN_NAMESPACE_NAME { template shared // fetch pipeline const bool has_act_order, // whether act_order is enabled const bool has_zp, // whether zero-points are enabled const int group_blocks, // number of consecutive 16x16 blocks // with a separate quantization scale const bool is_zp_float // is zero point of float16 type? > __global__ void Marlin(MARLIN_KERNEL_PARAMS); }