[Kernel] moe wna16 marlin kernel (#14447)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Jinzhen Lin 2025-04-15 11:05:22 +08:00 committed by GitHub
parent 6b40996ae8
commit d06ba4ed3f
16 changed files with 3477 additions and 329 deletions


@ -609,21 +609,51 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS)
set(MARLIN_MOE_SRC
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu"
"csrc/moe/marlin_moe_ops.cu")
#
# For the Marlin MOE kernels we automatically generate sources for various
# preselected input type pairs and schedules.
# Generate sources:
set(MOE_MARLIN_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py)
file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH)
message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}")
message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}")
if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH})
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
RESULT_VARIABLE moe_marlin_generation_result
OUTPUT_VARIABLE moe_marlin_generation_output
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
)
if (NOT moe_marlin_generation_result EQUAL 0)
message(FATAL_ERROR "Marlin MOE generation failed."
" Result: \"${moe_marlin_generation_result}\""
"\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log")
else()
set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH}
CACHE STRING "Last run Marlin MOE generate script hash" FORCE)
message(STATUS "Marlin MOE generation completed successfully.")
endif()
else()
message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
endif()
file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_SRC}"
SRCS "${MOE_WNAA16_MARLIN_SRC}"
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}")
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
else()
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"


@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
import glob
import itertools
import os
import subprocess
import jinja2
FILE_HEAD = """
// auto generated by generate_kernels.py
// clang-format off
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
""".strip()
TEMPLATE = ("template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{'true' if has_act_order else 'false'}}, "
"{{'true' if has_zp else 'false'}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );")
# The int8 with zero point case (vllm::kU8) is also supported,
# but we don't add it here to reduce wheel size.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]
def remove_old_kernels():
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
subprocess.call(["rm", "-f", filename])
def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
has_zp = "B" not in scalar_type
all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
has_act_order = group_blocks == 0
if has_zp and has_act_order:
continue
if thread_configs[2] == 256:
if m_blocks <= 1 and thread_configs[0] != 128:
continue
if m_blocks > 1 and thread_configs[0] != 64:
continue
k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16
threads = thread_configs[2]
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype,
w_type_id=scalar_type + ".id()",
threads=threads,
thread_m_blocks=max(m_blocks, 1),
thread_n_blocks=n_blocks,
thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages",
has_act_order=has_act_order,
has_zp=has_zp,
group_blocks=group_blocks,
is_zp_float=False,
)
all_template_str_list.append(template_str)
file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content)
if __name__ == "__main__":
remove_old_kernels()
generate_new_kernels()
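
For reference, each generated kernel_<dtype>_<type>.cu file is just a list of explicit template instantiations. A minimal sketch of one rendered entry (importing TEMPLATE from the script is an assumption made for illustration; group_blocks follows the encoding in the comment above: 0 = act_order, -1 = channelwise, otherwise group_size = 16 * group_blocks):

# Illustrative: render a single instantiation the same way generate_kernels.py does.
import jinja2

from generate_kernels import TEMPLATE  # assumed importable, for illustration only

entry = jinja2.Template(TEMPLATE).render(
    scalar_t="half",
    w_type_id="vllm::kU4B8.id()",
    threads=256,
    thread_m_blocks=1,
    thread_n_blocks=8,
    thread_k_blocks=8,
    m_block_size_8=True,      # i.e. the m_blocks == 0.5 case
    stages="pipe_stages",
    has_act_order=False,
    has_zp=False,
    group_blocks=8,           # group_size = 16 * 8 = 128
    is_zp_float=False,
)
# entry == "template __global__ void Marlin<half, vllm::kU4B8.id(), 256, 1, 8, 8, "
#          "true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );"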


@ -0,0 +1,44 @@
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const bool m_block_size_8, // whether m_block_size == 8
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
}
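
The host code later maps runtime sizes onto these template parameters in units of 16x16 tiles: thread_m_blocks = div_ceil(moe_block_size, 16), thread_n_blocks = thread_n / 16, thread_k_blocks = thread_k / 16, and m_block_size_8 covers the moe_block_size == 8 case. A small illustrative helper (not vLLM API):

# Illustrative mapping from runtime sizes to the template parameters above.
def template_params(moe_block_size: int, thread_n: int, thread_k: int) -> dict:
    thread_m_blocks = (moe_block_size + 15) // 16     # div_ceil(moe_block_size, 16)
    return dict(
        thread_m_blocks=max(thread_m_blocks, 1),
        thread_n_blocks=thread_n // 16,
        thread_k_blocks=thread_k // 16,
        m_block_size_8=(moe_block_size == 8),         # only used with thread_m_blocks == 1
    )

# e.g. moe_block_size=64 with the (thread_k=64, thread_n=256, threads=256) config:
assert template_params(64, thread_n=256, thread_k=64) == dict(
    thread_m_blocks=4, thread_n_blocks=16, thread_k_blocks=4, m_block_size_8=False)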

File diff suppressed because it is too large.


@ -0,0 +1,927 @@
/*
* Modified by Neural Magic
* Copyright (C) 2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Adapted from https://github.com/IST-DASLab/marlin
*/
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "kernel.h"
#include "core/registration.h"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \
"only float16 and bfloat16 is supported");
namespace MARLIN_NAMESPACE_NAME {
__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <int moe_block_size>
__global__ void permute_cols_kernel(
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr,
const int32_t* __restrict__ sorted_token_ids_ptr,
const int32_t* __restrict__ expert_ids_ptr,
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
int size_k, int top_k) {};
} // namespace MARLIN_NAMESPACE_NAME
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
template <int moe_block_size>
__global__ void permute_cols_kernel(
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr,
const int32_t* __restrict__ sorted_token_ids_ptr,
const int32_t* __restrict__ expert_ids_ptr,
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
int size_k, int top_k) {
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size);
int32_t block_sorted_ids[moe_block_size];
int block_num_valid_tokens = 0;
int64_t old_expert_id = 0;
int64_t expert_id = 0;
int row_stride = size_k * sizeof(half) / 16;
auto read_moe_block_data = [&](int block_id) {
block_num_valid_tokens = moe_block_size;
int4* tmp_block_sorted_ids = reinterpret_cast<int4*>(block_sorted_ids);
for (int i = 0; i < moe_block_size / 4; i++) {
tmp_block_sorted_ids[i] =
((int4*)sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];
}
for (int i = 0; i < moe_block_size; i++) {
if (block_sorted_ids[i] >= size_m * top_k) {
block_num_valid_tokens = i;
break;
};
}
};
auto permute_row = [&](int row) {
int iters = size_k / default_threads;
int rest = size_k % default_threads;
int in_offset = (row / top_k) * row_stride;
int out_offset = row * row_stride;
half const* a_row_half =
reinterpret_cast<half const*>(a_int4_ptr + in_offset);
half* out_half = reinterpret_cast<half*>(out_int4_ptr + out_offset);
int base_k = 0;
for (int i = 0; i < iters; i++) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
base_k += default_threads;
}
if (rest) {
if (threadIdx.x < rest) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
}
}
};
for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) {
old_expert_id = expert_id;
int tmp_expert_id = expert_ids_ptr[index];
if (tmp_expert_id == -1) continue;
expert_id = tmp_expert_id;
perm_int_ptr += (expert_id - old_expert_id) * size_k;
read_moe_block_data(index);
for (int i = 0; i < block_num_valid_tokens; i++)
permute_row(block_sorted_ids[i]);
}
}
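
Conceptually, for the act-order path this kernel gathers the K columns of each routed token through its expert's permutation while scattering the row to its (token, expert) slot in a_tmp. A NumPy reference of the same transform (a sketch, not the CUDA kernel; padding handling is simplified):

# Reference semantics of permute_cols_kernel for the act-order path (illustrative).
import numpy as np

def permute_cols_ref(a, perm_per_expert, sorted_token_ids, expert_ids,
                     moe_block_size, top_k):
    size_m, size_k = a.shape
    out = np.zeros((size_m * top_k, size_k), dtype=a.dtype)
    for block, expert in enumerate(expert_ids):
        if expert == -1:                        # expert not on this rank (EP)
            continue
        ids = sorted_token_ids[block * moe_block_size:(block + 1) * moe_block_size]
        for row in ids:
            if row >= size_m * top_k:           # padding slot
                continue
            # gather columns of the original token row through this expert's perm
            out[row] = a[row // top_k, perm_per_expert[expert]]
    return out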
typedef struct {
int thread_k;
int thread_n;
int num_threads;
} thread_config_t;
thread_config_t small_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{128, 128, 256},
{64, 128, 128}};
thread_config_t large_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{64, 256, 256},
{64, 128, 128}};
typedef struct {
int blocks_per_sm;
thread_config_t tb_cfg;
} exec_config_t;
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int prob_n, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full) {
bool cache_scales_chunk = has_act_order && !is_k_full;
int tb_n = th_config.thread_n;
int tb_k = th_config.thread_k;
// Get max scale groups per thread-block
int tb_groups;
if (group_size == -1) {
tb_groups = 1;
} else if (group_size == 0) {
tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size
} else {
tb_groups = div_ceil(tb_k, group_size);
}
if (cache_scales_chunk) {
int load_groups =
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2;
} else {
int tb_scales = tb_groups * tb_n * 2;
return tb_scales * pipe_stages;
}
}
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full,
int has_zp, int is_zp_float) {
int pack_factor = 32 / num_bits;
// Get B size
int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n;
int tb_m = thread_m_blocks * 16;
// shm size for block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int sh_block_meta_size = tb_m * 4 * 2;
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
int sh_zp_size = 0;
if (has_zp) {
if (is_zp_float)
sh_zp_size = sh_s_size;
else if (num_bits == 4)
sh_zp_size = sh_s_size / 4;
else if (num_bits == 8)
sh_zp_size = sh_s_size / 2;
}
int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size +
sh_g_idx_size + sh_block_meta_size;
return total_size;
}
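
A quick way to read get_kernel_cache_size: per pipeline stage the block keeps a fp16/bf16 tile of A, a packed tile of B, the scales (and optionally zero points), plus two small per-block metadata arrays. A Python mirror for the common non-act-order path, with a worked number (pipe_stages = 4 assumed):

# Illustrative mirror of the shared-memory estimate for the non-act-order path.
def kernel_cache_bytes(tb_m, tb_n, tb_k, num_bits, group_size,
                       has_zp=False, pipe_stages=4):
    pack_factor = 32 // num_bits
    sh_a = pipe_stages * tb_m * tb_k * 2                    # fp16/bf16 activations
    sh_b = pipe_stages * (tb_k * tb_n // pack_factor) * 4   # packed int4/int8 weights
    tb_groups = 1 if group_size == -1 else -(-tb_k // group_size)
    sh_s = tb_groups * tb_n * 2 * pipe_stages               # scales
    sh_zp = 0 if not has_zp else sh_s // (4 if num_bits == 4 else 2)
    sh_meta = tb_m * 4 * 2      # block_sorted_ids + block_topk_weights
    return sh_a + sh_b + sh_s + sh_zp + sh_meta

# e.g. the large-batch (thread_k=64, thread_n=256, threads=256) config with
# moe_block_size=64, 4-bit weights, group_size=128:
# 32 KiB A + 32 KiB B + 2 KiB scales + 0.5 KiB metadata ~= 66.5 KiB
print(kernel_cache_bytes(tb_m=64, tb_n=256, tb_k=64, num_bits=4, group_size=128))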
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full,
int has_zp, int is_zp_float, int max_shared_mem) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
return false;
}
// Verify K/N are divisible by thread K/N
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
return false;
}
// Verify min for thread K/N
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
return false;
}
// num_threads must be at least 128 (= 4 warps)
if (th_config.num_threads < 128) {
return false;
}
// Check that pipeline fits into cache
int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float);
return cache_size <= max_shared_mem;
}
#define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
IS_ZP_FLOAT>; \
}
#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false)
#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false)
#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false)
#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false)
// We currently have 4-bit models only with group_blocks == 4
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
true) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true)
template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
int thread_m_blocks, int thread_n_blocks,
int thread_k_blocks, bool m_block_size_8,
bool has_act_order, bool has_zp,
int group_blocks, int num_threads,
bool is_zp_float) {
int num_bits = q_type.size_bits();
auto kernel = MarlinDefault;
if (false) {
}
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256)
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128)
GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256)
GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128)
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256)
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128)
GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256)
GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128)
AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256)
AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128)
AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256)
AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128)
return kernel;
}
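
The __GET_IF ladder above is effectively a compile-time table: each macro expansion is one else-if that compares the runtime tuple (weight type, tile shape, act_order, zero point, group_blocks, thread count) against one of the instantiations emitted by generate_kernels.py, falling back to MarlinDefault when no match was built. A Python sketch of the same dispatch (names are hypothetical, not the vLLM API):

# Conceptual equivalent of the __GET_IF / get_marlin_kernel dispatch (illustrative).
KERNELS = {
    # (w_type, m_blocks, n_blocks, k_blocks, m8, act_order, zp, group_blocks, threads, zp_float)
    ("kU4B8", 1, 8, 8, True, False, False, 8, 256, False):
        "Marlin<half, kU4B8, 256, 1, 8, 8, true, ...>",
    # ... one entry per instantiation emitted by generate_kernels.py
}

def get_marlin_kernel(key):
    # MarlinDefault is a no-op; callers treat it as "this configuration was not compiled".
    return KERNELS.get(key, "MarlinDefault")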
template <typename scalar_t>
exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
int prob_n, int prob_k, int thread_m_blocks,
bool m_block_size_8, int num_bits,
int group_size, bool has_act_order,
bool is_k_full, bool has_zp,
bool is_zp_float, int max_shared_mem) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1
? large_batch_thread_configs
: small_batch_thread_configs;
int thread_configs_size =
thread_m_blocks > 1
? sizeof(large_batch_thread_configs) / sizeof(thread_config_t)
: sizeof(small_batch_thread_configs) / sizeof(thread_config_t);
int count = 0;
constexpr int device_max_reg_size = 255 * 1024;
for (int i = 0; i < thread_configs_size; i++) {
thread_config_t th_config = thread_configs[i];
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp,
is_zp_float, max_shared_mem)) {
continue;
}
int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full, has_zp, is_zp_float);
int group_blocks = 0;
if (!has_act_order) {
group_blocks = group_size == -1 ? -1 : group_size / 16;
}
auto kernel = get_marlin_kernel<scalar_t>(
q_type, thread_m_blocks, th_config.thread_n / 16,
th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp,
group_blocks, th_config.num_threads, is_zp_float);
if (kernel == MarlinDefault) continue;
if (thread_m_blocks > 1) {
exec_cfg = {1, th_config};
break;
} else {
cudaFuncAttributes attr;
cudaFuncGetAttributes(&attr, kernel);
int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
int allow_count = min(device_max_reg_size / reg_size,
max_shared_mem / (cache_size + 1024));
allow_count = max(min(allow_count, 4), 1);
if (allow_count > count) {
count = allow_count;
exec_cfg = {count, th_config};
};
}
}
return exec_cfg;
}
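
For small batches (thread_m_blocks == 1) the config search also estimates how many blocks can be co-resident on an SM, bounded by both the register file (approximated as 255 KiB) and shared memory, and clamped to at most 4; larger batches always use one block per SM. An illustrative mirror of that heuristic:

# Illustrative mirror of the small-batch blocks-per-SM heuristic above.
def blocks_per_sm(regs_per_thread, num_threads, cache_bytes, max_shared_mem,
                  device_max_reg_bytes=255 * 1024):
    reg_bytes = max(regs_per_thread, 1) * num_threads * 4
    allow = min(device_max_reg_bytes // reg_bytes,
                max_shared_mem // (cache_bytes + 1024))
    return max(min(allow, 4), 1)

# e.g. 64 registers/thread, 256 threads, ~40 KiB of shared memory per block on a
# device with 164 KiB opt-in shared memory -> 3 blocks per SM.
print(blocks_per_sm(64, 256, 40 * 1024, 164 * 1024))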
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
void* zp, void* g_idx, void* perm, void* a_tmp,
void* sorted_token_ids, void* expert_ids,
void* num_tokens_past_padded, void* topk_weights,
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
int prob_m, int prob_n, int prob_k, void* workspace,
vllm::ScalarType const& q_type, bool has_act_order,
bool is_k_full, bool has_zp, int num_groups, int group_size,
int dev, cudaStream_t stream, int thread_k, int thread_n,
int sms, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
int thread_m_blocks = div_ceil(moe_block_size, 16);
bool m_block_size_8 = moe_block_size == 8;
if (has_zp) {
TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else {
TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
q_type.str());
}
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");
int group_blocks = 0;
if (has_act_order) {
if (is_k_full) {
TORCH_CHECK(group_size != -1);
group_blocks = group_size / 16;
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
" is not divisible by group_blocks = ", group_blocks);
} else {
TORCH_CHECK(group_size == 0);
group_blocks = 0;
}
} else {
if (group_size == -1) {
group_blocks = -1;
} else {
group_blocks = group_size / 16;
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
" is not divisible by group_blocks = ", group_blocks);
}
}
int num_bits = q_type.size_bits();
const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp;
const int4* s_ptr = (const int4*)s;
const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm;
int4* a_tmp_ptr = (int4*)a_tmp;
const int32_t* sorted_token_ids_ptr = (const int32_t*)sorted_token_ids;
const int32_t* expert_ids_ptr = (const int32_t*)expert_ids;
const int32_t* num_tokens_past_padded_ptr =
(const int32_t*)num_tokens_past_padded;
const float* topk_weights_ptr = (const float*)topk_weights;
int* locks = (int*)workspace;
if (has_act_order) {
// Permute A columns
auto kernel = permute_cols_kernel<8>;
if (moe_block_size == 8) {
} else if (moe_block_size == 16)
kernel = permute_cols_kernel<16>;
else if (moe_block_size == 32)
kernel = permute_cols_kernel<32>;
else if (moe_block_size == 48)
kernel = permute_cols_kernel<48>;
else if (moe_block_size == 64)
kernel = permute_cols_kernel<64>;
else
TORCH_CHECK(false, "unsupported moe_block_size ", moe_block_size);
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel<<<sms, default_threads, 0, stream>>>(
A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr, expert_ids_ptr,
num_tokens_past_padded_ptr, prob_m, prob_k, top_k);
// clang-format on
A_ptr = a_tmp_ptr;
prob_m = prob_m * top_k;
top_k = 1;
// If we have a full K, then we can run the non-act-order version of Marlin
// (since the weight rows are reordered by increasing group ids, and by
// having a full K, we have full original groups)
if (is_k_full) has_act_order = false;
}
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
// Set thread config
exec_config_t exec_cfg;
thread_config_t thread_tfg;
if (thread_k != -1 && thread_n != -1) {
thread_tfg = thread_config_t{thread_k, thread_n, default_threads};
exec_cfg = exec_config_t{1, thread_tfg};
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
" is not divisible by thread_n = ", thread_n);
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
" is not divisible by thread_k = ", thread_k);
} else {
// Auto config
exec_cfg = determine_exec_config<scalar_t>(
q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
max_shared_mem);
thread_tfg = exec_cfg.tb_cfg;
}
int num_threads = thread_tfg.num_threads;
thread_k = thread_tfg.thread_k;
thread_n = thread_tfg.thread_n;
int blocks = sms * exec_cfg.blocks_per_sm;
if (exec_cfg.blocks_per_sm > 1)
max_shared_mem = max_shared_mem / exec_cfg.blocks_per_sm - 1024;
int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n,
prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n,
", num_threads = ", thread_tfg.num_threads, " for MKN = [",
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
", max_shared_mem = ", max_shared_mem);
auto kernel = get_marlin_kernel<scalar_t>(
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
has_act_order, has_zp, group_blocks, num_threads, is_zp_float);
if (kernel == MarlinDefault) {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
", ", prob_k, "]", ", has_act_order = ", has_act_order,
", num_groups = ", num_groups, ", group_size = ", group_size,
", thread_m_blocks = ", thread_m_blocks,
", thread_n_blocks = ", thread_n_blocks,
", thread_k_blocks = ", thread_k_blocks,
", num_bits = ", num_bits);
}
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
max_shared_mem);
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce);
// clang-format on
}
} // namespace MARLIN_NAMESPACE_NAME
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
int pack_factor = 32 / b_q_type.size_bits();
if (moe_block_size != 8) {
TORCH_CHECK(moe_block_size % 16 == 0,
"unsupported moe_block_size=", moe_block_size);
TORCH_CHECK(moe_block_size >= 16 && moe_block_size <= 64,
"unsupported moe_block_size=", moe_block_size);
}
// Verify A
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
", size_m = ", size_m);
TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
", size_k = ", size_k);
// Verify B
TORCH_CHECK(
size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, "size_k = ", size_k,
" is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
TORCH_CHECK((size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(1),
"Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
", size_k = ", size_k,
", tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
TORCH_CHECK(
b_q_weight.size(2) % MARLIN_NAMESPACE_NAME::tile_size == 0,
"b_q_weight.size(2) = ", b_q_weight.size(2),
" is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
int actual_size_n =
(b_q_weight.size(2) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor;
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
", actual_size_n = ", actual_size_n);
// Verify device and strides
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_k = -1;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_n = -1;
// sms: number of SMs to use for the kernel
int sms = -1;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c;
if (c_or_none.has_value()) {
c = c_or_none.value();
TORCH_CHECK(c.device().is_cuda(), "c is not on GPU");
TORCH_CHECK(c.is_contiguous(), "c is not contiguous");
TORCH_CHECK(c.size(0) == size_m * top_k,
"Shape mismatch: c.size(0) = ", c.size(0),
", size_m * topk = ", size_m * top_k);
TORCH_CHECK(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1),
", size_n = ", size_n);
} else {
c = torch::empty({size_m * top_k, size_n}, options);
}
// Alloc C tmp buffer that is going to be used for the global reduce
torch::Tensor c_tmp;
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (use_fp32_reduce && !use_atomic_add) {
// max num of threadblocks is sms * 4
long max_c_tmp_size = min(
(long)size_n * sorted_token_ids.size(0),
(long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n);
if (moe_block_size == 8) max_c_tmp_size *= 2;
c_tmp = torch::empty({max_c_tmp_size}, options_fp32);
} else {
c_tmp = torch::empty({0}, options_fp32);
}
// Detect groupsize and act_order
int num_groups = -1;
int group_size = -1;
int rank = b_scales.sizes().size();
TORCH_CHECK(rank == 3, "b_scales rank = ", rank, " is not 3");
TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
" is not size_n = ", size_n);
num_groups = b_scales.size(1);
torch::Tensor g_idx, perm, a_tmp;
if (g_idx_or_none.has_value() && perm_or_none.has_value()) {
g_idx = g_idx_or_none.value();
perm = perm_or_none.value();
TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
// Verify g_idx and perm
TORCH_CHECK((g_idx.size(-1) == 0 && perm.size(-1) == 0) ||
(g_idx.size(-1) == size_k && perm.size(-1) == size_k),
"Unexpected g_idx.size(-1) = ", g_idx.size(-1),
" and perm.size(-1) = ", perm.size(-1),
", where size_k = ", size_k);
} else {
g_idx = torch::empty({0}, options);
perm = torch::empty({0}, options);
a_tmp = torch::empty({0}, options);
}
bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0;
if (has_act_order) {
a_tmp = torch::empty({size_m * top_k, size_k}, options);
if (is_k_full) {
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
", is not divisible by num_groups = ", num_groups);
group_size = size_k / num_groups;
} else {
group_size = 0;
}
} else {
a_tmp = torch::empty({0}, options);
if (num_groups > 1) {
TORCH_CHECK(
size_k % num_groups == 0, "size_k = ", size_k,
", is not divisible by b_scales.size(1) = ", b_scales.size(1));
group_size = size_k / num_groups;
} else {
group_size = -1;
}
}
torch::Tensor b_zeros;
if (b_zeros_or_none.has_value()) {
b_zeros = b_zeros_or_none.value();
TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
} else {
b_zeros = torch::empty({0}, options);
}
bool has_zp = b_zeros.size(-1) > 0;
if (has_zp) {
TORCH_CHECK(
b_q_type == vllm::kU4,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
} else {
TORCH_CHECK(
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
b_q_type.str());
}
if (has_zp && is_zp_float) {
TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
"Computation type must be float16 (half) when using float zero "
"points.");
}
// Verify b_zeros
if (has_zp) {
int rank = b_zeros.sizes().size();
TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
if (is_zp_float) {
TORCH_CHECK(b_zeros.size(2) == size_n,
"b_zeros dim 2 = ", b_zeros.size(2),
" is not size_n = ", size_n);
TORCH_CHECK(num_groups == b_zeros.size(1),
"b_zeros dim 1 = ", b_zeros.size(1),
" is not num_groups = ", num_groups);
TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
} else {
TORCH_CHECK(b_zeros.size(1) == num_groups,
"b_zeros dim 1 = ", b_zeros.size(1),
" is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
"b_zeros dim 2 = ", b_zeros.size(2),
" is not size_n / pack_factor = ", size_n / pack_factor);
}
}
// Verify workspace size
TORCH_CHECK(size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0,
"size_n = ", size_n, ", is not divisible by min_thread_n = ",
MARLIN_NAMESPACE_NAME::min_thread_n);
int max_n_tiles = size_n / MARLIN_NAMESPACE_NAME::min_thread_n;
int min_workspace_size = min(
max_n_tiles * (int)(sorted_token_ids.size(0) / moe_block_size), sms * 4);
TORCH_CHECK(workspace.numel() >= min_workspace_size,
"workspace.numel = ", workspace.numel(),
" is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep,
size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order,
is_k_full, has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
} else {
TORCH_CHECK(false,
"moe_wna16_marlin_gemm only supports bfloat16 and float16");
}
return c;
}
#endif
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
}
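
One practical note on the workspace argument: it only backs the inter-block locks, so the required element count is the smaller of n_tiles * num_m_blocks and sms * 4 (the cap on concurrently resident blocks). A worked example with assumed sizes:

# Illustrative: the minimum workspace (lock) element count checked in
# moe_wna16_marlin_gemm, with min_thread_n = 64 as defined in marlin.cuh.
def min_workspace_size(size_n, num_sorted_token_ids, moe_block_size, sms,
                       min_thread_n=64):
    max_n_tiles = size_n // min_thread_n
    return min(max_n_tiles * (num_sorted_token_ids // moe_block_size), sms * 4)

# e.g. size_n=4096, 512 sorted token ids, moe_block_size=64, 108 SMs:
# min(64 * 8, 432) = 432
print(min_workspace_size(4096, 512, 64, 108))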


@ -43,14 +43,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);
m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"int b_q_type, SymInt size_m, "
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
"topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor");
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
// conditionally compiled so impl registration is in source file
#endif


@ -9,7 +9,11 @@
#include <cuda_runtime.h>
#include <iostream>
namespace marlin {
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
namespace MARLIN_NAMESPACE_NAME {
// Marlin params
@ -23,6 +27,7 @@ static constexpr int pipe_stages =
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
static constexpr int max_thread_n = 256;
static constexpr int tile_size = 16;
static constexpr int max_par = 16;
@ -84,4 +89,4 @@ __device__ inline void cp_async_wait() {
#endif
} // namespace marlin
} // namespace MARLIN_NAMESPACE_NAME


@ -5,7 +5,11 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>
namespace marlin {
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t>
class ScalarType {};
@ -54,7 +58,7 @@ class ScalarType<nv_bfloat16> {
using FragS = Vec<nv_bfloat162, 1>;
using FragZP = Vec<nv_bfloat162, 4>;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) {
return __bfloat162float(x);
}
@ -74,6 +78,6 @@ class ScalarType<nv_bfloat16> {
#endif
};
} // namespace marlin
} // namespace MARLIN_NAMESPACE_NAME
#endif


@ -11,16 +11,14 @@ from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
torch_moe, torch_moe_single)
from vllm import _custom_ops as ops
from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe,
torch_moe_single)
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
awq_marlin_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights)
from vllm.model_executor.models.mixtral import MixtralMoE
@ -287,14 +285,17 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
atol=mixtral_moe_tol[dtype])
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("m", [1, 33, 123])
@pytest.mark.parametrize("n", [128, 1024])
@pytest.mark.parametrize("k", [256, 2048])
@pytest.mark.parametrize("e", [4, 12])
@pytest.mark.parametrize("topk", [2, 3])
@pytest.mark.parametrize("ep_size", [1, 4])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [-1, 32, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
@ -303,9 +304,12 @@ def test_fused_marlin_moe(
k: int,
e: int,
topk: int,
ep_size: int,
dtype: torch.dtype,
group_size: int,
act_order: bool,
num_bits: int,
has_zp: bool,
is_k_full: bool,
):
current_platform.seed_everything(7)
@ -316,29 +320,57 @@ def test_fused_marlin_moe(
return
if group_size in (k, n):
return
if has_zp:
return
else:
if not is_k_full:
return
quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
dtype = torch.float16
if has_zp:
# we don't build kernels for int8 with zero points
if num_bits == 8:
return
quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
else:
quant_type = scalar_types.uint4b8 \
if num_bits == 4 else scalar_types.uint8b128
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
if ep_size > 1:
local_e = e // ep_size
e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
w1 = w1[e_ids]
w2 = w2[e_ids]
else:
e_map = None
w_ref1_l = []
qweight1_l = []
scales1_l = []
zeros1_l = []
g_idx1_l = []
sort_indices1_l = []
for i in range(w1.shape[0]):
if has_zp:
w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
zeros1_l.append(zeros1)
else:
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref1_l.append(w_ref1)
quant_res = marlin_quantize(w1[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
@ -347,21 +379,33 @@ def test_fused_marlin_moe(
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
g_idx1 = stack_and_dev(g_idx1_l)
sort_indices1 = stack_and_dev(sort_indices1_l)
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
w_ref2_l = []
qweight2_l = []
scales2_l = []
zeros2_l = []
g_idx2_l = []
sort_indices2_l = []
for i in range(w2.shape[0]):
if has_zp:
w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
zeros2_l.append(zeros2)
else:
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref2_l.append(w_ref2)
quant_res = marlin_quantize(w2[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
@ -370,21 +414,16 @@ def test_fused_marlin_moe(
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
g_idx2 = stack_and_dev(g_idx2_l)
sort_indices2 = stack_and_dev(sort_indices2_l)
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, False)
triton_output = fused_moe(
a,
w_ref1.transpose(1, 2).contiguous(),
w_ref2.transpose(1, 2).contiguous(),
score,
topk,
renormalize=False,
)
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
qweight1,
@ -394,111 +433,91 @@ def test_fused_marlin_moe(
score,
topk_weights,
topk_ids,
global_num_experts=e,
expert_map=e_map,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=zeros1,
w2_zeros=zeros2,
num_bits=num_bits,
is_k_full=is_k_full,
)
is_k_full=is_k_full)
assert compute_max_diff(marlin_output, triton_output) < 4e-2
if ops.supports_moe_ops:
token_expert_indicies = torch.empty(m,
topk,
dtype=torch.int32,
device=a.device)
opcheck(torch.ops._moe_C.topk_softmax, (
topk_weights,
topk_ids,
token_expert_indicies,
score.float(),
))
block_size_m = 4
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m,
e)
max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)
zp = torch.empty((0, 0),
dtype=dtype,
device="cuda",
requires_grad=False)
opcheck(torch.ops._moe_C.marlin_gemm_moe,
(a, qweight1, sorted_token_ids, topk_weights, topk_ids,
scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id,
m, 2 * n, k, True, e, topk, block_size_m, True, False))
torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
@pytest.mark.skip("This test is here for the sake of debugging, "
"don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("m", [1, 33, 123])
@pytest.mark.parametrize("n", [128, 1024])
@pytest.mark.parametrize("k", [256, 2048])
@pytest.mark.parametrize("e", [4, 12])
@pytest.mark.parametrize("topk", [2, 3])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [-1, 32, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_single_marlin_moe_multiply(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
num_bits: int,
is_k_full: bool,
):
def test_single_marlin_moe_multiply(m: int, n: int, k: int, e: int, topk: int,
dtype: torch.dtype, group_size: int,
act_order: bool, num_bits: int,
has_zp: bool, is_k_full: bool):
# Filter act_order
if act_order:
if group_size == -1:
return
if group_size == k:
if group_size in (k, n):
return
if has_zp:
return
else:
if not is_k_full:
return
quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
dtype = torch.float16
if has_zp:
quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
else:
quant_type = scalar_types.uint4b8 \
if num_bits == 4 else scalar_types.uint8b128
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
w_ref_l = []
qweights_l = []
qweight_l = []
scales_l = []
zeros_l = []
g_idx_l = []
sort_indices_l = []
for i in range(w.shape[0]):
if has_zp:
w_ref, qweight, scales, zeros = awq_marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size)
w_ref_l.append(w_ref.T)
qweight_l.append(qweight)
scales_l.append(scales)
zeros_l.append(zeros)
else:
test_perm = torch.randperm(k)
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
w_ref_l.append(w_ref)
qweights_l.append(qweight)
w[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref_l.append(w_ref.T)
qweight_l.append(qweight)
scales_l.append(scales)
g_idx_l.append(g_idx)
sort_indices_l.append(sort_indices)
w_ref = stack_and_dev(w_ref_l)
qweight = stack_and_dev(qweights_l).contiguous()
qweight = stack_and_dev(qweight_l).contiguous()
scales = stack_and_dev(scales_l)
g_idx = stack_and_dev(g_idx_l)
sort_indices = stack_and_dev(sort_indices_l)
g_idx = stack_and_dev(g_idx_l) if g_idx_l else None
zeros = stack_and_dev(zeros_l) if zeros_l else None
sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = torch.ops.vllm.single_marlin_moe(
@ -510,13 +529,14 @@ def test_single_marlin_moe_multiply(
renormalize=False,
g_idx=g_idx,
sort_indices=sort_indices,
w_zeros=zeros,
num_bits=num_bits,
is_k_full=is_k_full,
)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
torch_output = torch_moe_single(a, w_ref, score, topk)
assert compute_max_diff(marlin_output, torch_output) < 1e-2
torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
def test_moe_align_block_size_opcheck():


@ -1245,6 +1245,29 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies, gating_output)
def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
b_qweight: torch.Tensor, b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor],
g_idx: Optional[torch.Tensor],
perm: Optional[torch.Tensor],
workspace: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_past_padded: torch.Tensor,
topk_weights: torch.Tensor, moe_block_size: int,
top_k: int, mul_topk_weights: bool, is_ep: bool,
b_q_type: ScalarType, size_m: int, size_n: int,
size_k: int, is_k_full: bool, use_atomic_add: bool,
use_fp32_reduce: bool,
is_zp_float: bool) -> torch.Tensor:
return torch.ops._moe_C.moe_wna16_marlin_gemm(
input, output, b_qweight, b_scales, b_qzeros, g_idx, perm, workspace,
sorted_token_ids, expert_ids, num_tokens_past_padded, topk_weights,
moe_block_size, top_k, mul_topk_weights, is_ep, b_q_type.id, size_m,
size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce,
is_zp_float)
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
@register_fake("_moe_C::marlin_gemm_moe")
@ -1263,6 +1286,29 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
dtype=a.dtype,
device=a.device)
@register_fake("_moe_C::moe_wna16_marlin_gemm")
def moe_wna16_marlin_gemm_fake(input: torch.Tensor,
output: Optional[torch.Tensor],
b_qweight: torch.Tensor,
b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor],
g_idx: Optional[torch.Tensor],
perm: Optional[torch.Tensor],
workspace: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_past_padded: torch.Tensor,
topk_weights: torch.Tensor,
moe_block_size: int, top_k: int,
mul_topk_weights: bool, is_ep: bool,
b_q_type: ScalarType, size_m: int,
size_n: int, size_k: int, is_k_full: bool,
use_atomic_add: bool, use_fp32_reduce: bool,
is_zp_float: bool) -> torch.Tensor:
return torch.empty((size_m * top_k, size_n),
dtype=input.dtype,
device=input.device)
def reshape_and_cache(
key: torch.Tensor,


@ -5,17 +5,16 @@ from typing import Optional
import torch
import vllm._custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import direct_register_custom_op
def get_scalar_type(num_bits: int, has_zp: bool):
if has_zp:
assert num_bits == 4
return scalar_types.uint4
return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
else:
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
@ -27,9 +26,12 @@ def single_marlin_moe(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
@ -62,7 +64,7 @@ def single_marlin_moe(
assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w.is_contiguous(), "Expert weights must be contiguous"
assert hidden_states.dtype == torch.float16
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
assert num_bits in [4, 8]
M, K = hidden_states.shape
@ -83,39 +85,54 @@ def single_marlin_moe(
block_size_m = config['BLOCK_SIZE_M']
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
if global_num_experts == -1:
global_num_experts = E
sorted_token_ids, expert_ids, num_tokens_post_padded = \
moe_align_block_size(topk_ids, block_size_m, E, expert_map)
max_workspace_size = (N // 64) * 16
if workspace is None:
max_workspace_size = (max(2 * N, K) // 64) * \
(sorted_token_ids.size(0) // block_size_m)
device = hidden_states.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
max_workspace_size = min(max_workspace_size, sms)
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device=hidden_states.device,
device=device,
requires_grad=False)
has_zero_point = w_zeros is not None
if w_zeros is None:
w_zeros = torch.empty((0, 0),
scalar_type = get_scalar_type(num_bits, w_zeros is not None)
intermediate_cache = torch.empty(
(M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
)
if g_idx is None:
g_idx = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
if sort_indices is None:
sort_indices = torch.empty((0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
scalar_type = get_scalar_type(num_bits, has_zero_point)
intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
w_zeros, g_idx, sort_indices, workspace, scalar_type.id, M, N, K,
is_k_full, E, topk, block_size_m, True, False)
ops.moe_wna16_marlin_gemm(hidden_states,
intermediate_cache,
w,
scales,
w_zeros,
g_idx,
sort_indices,
workspace,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=topk,
mul_topk_weights=False,
is_ep=expert_map is not None,
b_q_type=scalar_type,
size_m=M,
size_n=N,
size_k=K,
is_k_full=is_k_full,
use_atomic_add=False,
use_fp32_reduce=True,
is_zp_float=False)
intermediate_cache = intermediate_cache.view(-1, topk, N)
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
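The default workspace sizing above allocates one int slot per potential output tile per token block and then caps the result by the multiprocessor count (sms here, sms * 4 in fused_marlin_moe below). A worked example with hypothetical shapes, for illustration only:

# Hypothetical sizes, illustrating the workspace sizing in single_marlin_moe.
N, K = 4096, 4096                       # expert weight shapes
num_token_blocks = 12                   # sorted_token_ids.size(0) // block_size_m
sms = 108                               # e.g. an A100 multiprocessor count
max_workspace_size = (max(2 * N, K) // 64) * num_token_blocks   # 128 * 12 = 1536
max_workspace_size = min(max_workspace_size, sms)               # capped at 108 ints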
@ -127,9 +144,12 @@ def single_marlin_moe_fake(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
@ -144,8 +164,7 @@ direct_register_custom_op(
)
def fused_marlin_moe(
hidden_states: torch.Tensor,
def fused_marlin_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
@ -153,15 +172,18 @@ def fused_marlin_moe(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
inplace: bool = False) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
@ -196,27 +218,12 @@ def fused_marlin_moe(
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[1] == w2.shape[2] // (
num_bits // 2), "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype == torch.float16
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
assert num_bits in [4, 8]
has_no_act_order = (g_idx1 is None and g_idx2 is None
and sort_indices1 is None and sort_indices2 is None)
has_all_act_order = (g_idx1 is not None and g_idx2 is not None
and sort_indices1 is not None
and sort_indices2 is not None)
assert has_no_act_order or has_all_act_order, (
"g_idx and sorted_indices "
"must be all not None or must be all None")
has_no_zp = w1_zeros is None and w2_zeros is None
has_all_zp = w1_zeros is not None and w2_zeros is not None
assert has_no_zp or has_all_zp, ("zero points must be both not None or "
"must be both None")
M, K = hidden_states.shape
E = w1.shape[0]
N = w2.shape[1] * 16
@ -234,106 +241,109 @@ def fused_marlin_moe(
block_size_m = config["BLOCK_SIZE_M"]
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
if global_num_experts == -1:
global_num_experts = E
sorted_token_ids, expert_ids, num_tokens_post_padded = \
moe_align_block_size(topk_ids, block_size_m, global_num_experts,
expert_map)
max_workspace_size = (max(2 * N, K) // 64) * 16
if workspace is None:
max_workspace_size = (max(2 * N, K) // 64) * \
(sorted_token_ids.size(0) // block_size_m)
device = hidden_states.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
max_workspace_size = min(max_workspace_size, sms * 4)
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device=current_platform.device_type,
device=device,
requires_grad=False)
if has_no_zp:
w1_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
w2_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
if has_no_act_order:
g_idx1 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
g_idx2 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
sort_indices1 = torch.empty((0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
sort_indices2 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
scalar_type1 = get_scalar_type(num_bits, has_all_zp)
scalar_type2 = get_scalar_type(num_bits, has_all_zp)
scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache13 = torch.empty(
(M * topk_ids.shape[1] * max(2 * N, K), ),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N]
intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K]
intermediate_cache3 = intermediate_cache3.view(-1, K)
intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
use_atomic_add = hidden_states.dtype == torch.half or \
torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
intermediate_cache1 = ops.moe_wna16_marlin_gemm(
hidden_states,
intermediate_cache1,
w1,
sorted_token_ids,
topk_weights,
topk_ids,
w1_scale,
w1_zeros,
g_idx1,
sort_indices1,
workspace,
scalar_type1.id,
M,
2 * N,
K,
is_k_full,
E,
topk,
block_size_m,
True,
False,
)
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=topk,
mul_topk_weights=False,
is_ep=expert_map is not None,
b_q_type=scalar_type1,
size_m=M,
size_n=2 * N,
size_k=K,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False)
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, 2 * N))
intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
if expert_map is not None:
intermediate_cache3.zero_()
intermediate_cache3 = ops.moe_wna16_marlin_gemm(
intermediate_cache2,
intermediate_cache3,
w2,
sorted_token_ids,
topk_weights,
topk_ids,
w2_scale,
w2_zeros,
g_idx2,
sort_indices2,
workspace,
scalar_type2.id,
M,
K,
N,
is_k_full,
E,
topk,
block_size_m,
False,
True,
)
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=1,
mul_topk_weights=True,
is_ep=expert_map is not None,
b_q_type=scalar_type2,
size_m=M * topk,
size_n=K,
size_k=N,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False).view(-1, topk, K)
output = hidden_states if inplace else torch.empty_like(hidden_states)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
dim=1,
out=output)
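intermediate_cache1 and intermediate_cache3 above are carved out of a single flat buffer (intermediate_cache13), which works because the gate-up output is fully consumed by silu_and_mul before the down projection writes its result. A minimal sketch of that aliasing, with hypothetical sizes:

import torch

# Hypothetical sizes, illustrating the intermediate_cache13 buffer sharing above.
M, topk, N, K = 4, 2, 2048, 4096
cache13 = torch.empty(M * topk * max(2 * N, K))        # one flat allocation
cache1 = cache13[:M * topk * 2 * N].view(-1, 2 * N)    # gate-up output, (M*topk, 2N)
cache3 = cache13[:M * topk * K].view(-1, K)            # down output, (M*topk, K)
# The two views alias the same storage; cache1 is no longer live once the
# second GEMM starts writing into cache3.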
def fused_marlin_moe_fake(
hidden_states: torch.Tensor,
def fused_marlin_moe_fake(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
@ -341,15 +351,18 @@ def fused_marlin_moe_fake(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
inplace: bool = False) -> torch.Tensor:
return torch.empty_like(hidden_states)

View File

@ -773,6 +773,18 @@ def get_default_config(
config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
else:
config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
elif is_marlin:
for block_size_m in [8, 16, 32, 48, 64]:
if M * topk / E / block_size_m < 0.9:
break
return {"BLOCK_SIZE_M": block_size_m}
elif M <= E:
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
else:
config = {
"BLOCK_SIZE_M": 64,
@ -780,14 +792,6 @@ def get_default_config(
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
}
# A heuristic: fused marlin works faster with this config for small M
if M <= E or (is_marlin and M <= 32):
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
return config
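The new is_marlin branch above grows BLOCK_SIZE_M until the average number of tokens per expert no longer fills about 90% of a block (M * topk / E / block_size_m < 0.9), defaulting to 64 when routing stays dense. A worked example with hypothetical shapes:

# Worked example of the is_marlin branch (hypothetical M, E, topk values).
def marlin_block_size(M: int, E: int, topk: int) -> int:
    for block_size_m in [8, 16, 32, 48, 64]:
        if M * topk / E / block_size_m < 0.9:
            break
    return block_size_m

assert marlin_block_size(M=128, E=8, topk=2) == 48    # 32 tokens/expert: stop at 48
assert marlin_block_size(M=4, E=64, topk=2) == 8      # sparse routing: smallest block
assert marlin_block_size(M=4096, E=8, topk=2) == 64   # dense routing: loop never breaks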

View File

@ -472,6 +472,7 @@ class FusedMoE(torch.nn.Module):
self.global_num_experts = num_experts
assert intermediate_size % self.tp_size == 0
self.hidden_size = hidden_size
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize

View File

@ -17,14 +17,13 @@ from vllm.model_executor.layers.quantization.awq import (AWQConfig,
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
check_marlin_supports_layer, marlin_make_empty_g_idx,
marlin_make_workspace, marlin_moe_permute_scales, marlin_permute_scales,
moe_awq_to_marlin_zero_points, verify_marlin_supported,
verify_marlin_supports_shape)
check_marlin_supports_layer, check_moe_marlin_supports_layer,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
marlin_permute_scales, moe_awq_to_marlin_zero_points,
verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
@ -136,11 +135,14 @@ class AWQMarlinConfig(QuantizationConfig):
self.full_config).get_quant_method(layer, prefix)
return AWQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE):
if layer.local_num_experts > 32:
# For MoEs with many experts the moe_wna16 kernel is faster
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
"Falling back to Moe WNA16 kernels.")
return MoeWNA16Config.from_config(
self.full_config).get_quant_method(layer, prefix)
else:
return AWQMoEMethod(self)
return None
@ -391,6 +393,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
device = layer.w13_qweight.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
layer.workspace = torch.zeros((sms * 4, ),
dtype=torch.int,
device=device,
requires_grad=False)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0]
device = layer.w13_qweight.device
@ -473,10 +482,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
if expert_map is not None:
raise NotImplementedError(
"Expert Parallelism is not supported for "
"fused Marlin MoE method.")
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
@ -503,7 +509,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
router_logits,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,
workspace=layer.workspace,
num_bits=self.quant_config.weight_bits,
)

View File

@ -15,13 +15,13 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, marlin_moe_permute_scales,
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
check_marlin_supported, check_moe_marlin_supports_layer,
marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks,
verify_marlin_supported)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
@ -153,11 +153,14 @@ class GPTQMarlinConfig(QuantizationConfig):
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, FusedMoE):
if layer.local_num_experts > 32:
# For MoEs with many experts the moe_wna16 kernel is faster
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
f"Layer '{prefix}' is not supported by GPTQMoeMarlin. "
"Falling back to Moe WNA16 kernels.")
return MoeWNA16Config.from_config(
self.full_config).get_quant_method(layer, prefix)
else:
return GPTQMarlinMoEMethod(self)
return get_linear_quant_method(self, layer, prefix,
GPTQMarlinLinearMethod)
@ -408,7 +411,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
torch.empty(num_experts,
scales_size13,
2 * intermediate_size_per_partition,
dtype=torch.half),
dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w13_scales", w13_scales)
@ -418,7 +421,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
torch.empty(num_experts,
scales_size2,
hidden_size,
dtype=torch.half),
dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_scales", w2_scales)
@ -493,6 +496,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w2_g_idx_sort_indices)
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
device = layer.w13_qweight.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
layer.workspace = torch.zeros((sms * 4, ),
dtype=torch.int,
device=device,
requires_grad=False)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Process act_order
@ -601,10 +611,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
"Apply router weight on input is not supported for"
"fused Marlin MoE method.")
# The input must currently be float16
orig_dtype = x.dtype
x = x.half()
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@ -626,9 +632,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
router_logits,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.quant_config.quant_type.size_bits,
is_k_full=self.is_k_full).to(orig_dtype)
workspace=layer.workspace,
is_k_full=self.is_k_full)

View File

@ -151,6 +151,19 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
group_size=group_size)[0]
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
-> bool:
hidden_size = layer.hidden_size
intermediate_size_per_partition = layer.intermediate_size_per_partition
# gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
# down: (n, k) = (hidden_size, intermediate_size_per_partition)
# moe marlin requires n % 128 == 0 and k % 64 == 0
return hidden_size % 128 == 0 and \
intermediate_size_per_partition % max(64, group_size) == 0 and \
group_size in [-1, 32, 64, 128]
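Both the gate-up and down projections are covered by this single check because the two (n, k) shapes named in the comment reuse hidden_size and intermediate_size_per_partition. A worked example with hypothetical Llama-like sizes:

# Hypothetical shapes, illustrating check_moe_marlin_supports_layer above.
hidden_size = 4096                        # k of gate-up, n of down
intermediate_size_per_partition = 11008   # n/2 of gate-up, k of down
group_size = 128

ok = (hidden_size % 128 == 0                                           # 4096 % 128 == 0
      and intermediate_size_per_partition % max(64, group_size) == 0   # 11008 % 128 == 0
      and group_size in [-1, 32, 64, 128])
assert ok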
def marlin_make_workspace(output_size_per_partition: int,
device: torch.device) -> torch.Tensor:
max_workspace_size = (output_size_per_partition //