From d06ba4ed3f9a5929eabb404842a5c02da42e960b Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 15 Apr 2025 11:05:22 +0800 Subject: [PATCH] [Kernel] moe wna16 marlin kernel (#14447) Signed-off-by: Jinzhen Lin Co-authored-by: Michael Goin Co-authored-by: mgoin --- CMakeLists.txt | 52 +- csrc/moe/marlin_moe_wna16/generate_kernels.py | 103 + csrc/moe/marlin_moe_wna16/kernel.h | 44 + csrc/moe/marlin_moe_wna16/marlin_template.h | 1917 +++++++++++++++++ csrc/moe/marlin_moe_wna16/ops.cu | 927 ++++++++ csrc/moe/torch_bindings.cpp | 19 +- csrc/quantization/gptq_marlin/marlin.cuh | 9 +- .../gptq_marlin/marlin_dtypes.cuh | 10 +- tests/kernels/test_moe.py | 254 ++- vllm/_custom_ops.py | 46 + .../layers/fused_moe/fused_marlin_moe.py | 319 +-- .../layers/fused_moe/fused_moe.py | 20 +- vllm/model_executor/layers/fused_moe/layer.py | 1 + .../layers/quantization/awq_marlin.py | 35 +- .../layers/quantization/gptq_marlin.py | 37 +- .../layers/quantization/utils/marlin_utils.py | 13 + 16 files changed, 3477 insertions(+), 329 deletions(-) create mode 100644 csrc/moe/marlin_moe_wna16/generate_kernels.py create mode 100644 csrc/moe/marlin_moe_wna16/kernel.h create mode 100644 csrc/moe/marlin_moe_wna16/marlin_template.h create mode 100644 csrc/moe/marlin_moe_wna16/ops.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index a0c25df6..4f4b20d3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -609,21 +609,51 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") if (MARLIN_MOE_ARCHS) - set(MARLIN_MOE_SRC - "csrc/moe/marlin_kernels/marlin_moe_kernel.h" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu" - "csrc/moe/marlin_moe_ops.cu") + # + # For the Marlin MOE kernels we automatically generate sources for various + # preselected input type pairs and schedules. + # Generate sources: + set(MOE_MARLIN_GEN_SCRIPT + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py) + file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH) + + message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}") + message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}") + + if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} + OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH}) + execute_process( + COMMAND ${CMAKE_COMMAND} -E env + PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH + ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} + RESULT_VARIABLE moe_marlin_generation_result + OUTPUT_VARIABLE moe_marlin_generation_output + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log + ) + + if (NOT moe_marlin_generation_result EQUAL 0) + message(FATAL_ERROR "Marlin MOE generation failed." 
+ " Result: \"${moe_marlin_generation_result}\"" + "\nCheck the log for details: " + "${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log") + else() + set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH} + CACHE STRING "Last run Marlin MOE generate script hash" FORCE) + message(STATUS "Marlin MOE generation completed successfully.") + endif() + else() + message(STATUS "Marlin MOE generation script has not changed, skipping generation.") + endif() + + file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu") set_gencode_flags_for_srcs( - SRCS "${MARLIN_MOE_SRC}" + SRCS "${MOE_WNAA16_MARLIN_SRC}" CUDA_ARCHS "${MARLIN_MOE_ARCHS}") - list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}") + list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC}) + message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") else() message(STATUS "Not building Marlin MOE kernels as no compatible archs found" diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py new file mode 100644 index 00000000..d1c0d92f --- /dev/null +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +import glob +import itertools +import os +import subprocess + +import jinja2 + +FILE_HEAD = """ +// auto generated by generate.py +// clang-format off + +#include "kernel.h" +#include "marlin_template.h" + +namespace MARLIN_NAMESPACE_NAME { +""".strip() + +TEMPLATE = ("template __global__ void Marlin<" + "{{scalar_t}}, " + "{{w_type_id}}, " + "{{threads}}, " + "{{thread_m_blocks}}, " + "{{thread_n_blocks}}, " + "{{thread_k_blocks}}, " + "{{'true' if m_block_size_8 else 'false'}}, " + "{{stages}}, " + "{{'true' if has_act_order else 'false'}}, " + "{{'true' if has_zp else 'false'}}, " + "{{group_blocks}}, " + "{{'true' if is_zp_float else 'false'}}>" + "( MARLIN_KERNEL_PARAMS );") + +# int8 with zero point case (vllm::kU8) is also supported, +# we don't add it to reduce wheel size. 
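+# Naming sketch (per vLLM's scalar-type convention): a trailing "B<n>" marks a
+# symmetric bias of n, so kU4B8 is uint4 stored with bias 8 (GPTQ-style) and
+# kU8B128 is uint8 with bias 128, while plain kU4 carries an explicit
+# zero point (AWQ-style). This is what `has_zp = "B" not in scalar_type`
+# below relies on.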
+SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"] +THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] + +THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] +# group_blocks: +# = 0 : act order case +# = -1 : channelwise quantization +# > 0 : group_size=16*group_blocks +GROUP_BLOCKS = [0, -1, 2, 4, 8] +DTYPES = ["fp16", "bf16"] + + +def remove_old_kernels(): + for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): + subprocess.call(["rm", "-f", filename]) + + +def generate_new_kernels(): + for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): + has_zp = "B" not in scalar_type + all_template_str_list = [] + + for group_blocks, m_blocks, thread_configs in itertools.product( + GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): + + has_act_order = group_blocks == 0 + if has_zp and has_act_order: + continue + if thread_configs[2] == 256: + if m_blocks <= 1 and thread_configs[0] != 128: + continue + if m_blocks > 1 and thread_configs[0] != 64: + continue + + k_blocks = thread_configs[0] // 16 + n_blocks = thread_configs[1] // 16 + threads = thread_configs[2] + + c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" + + template_str = jinja2.Template(TEMPLATE).render( + scalar_t=c_dtype, + w_type_id=scalar_type + ".id()", + threads=threads, + thread_m_blocks=max(m_blocks, 1), + thread_n_blocks=n_blocks, + thread_k_blocks=k_blocks, + m_block_size_8=m_blocks == 0.5, + stages="pipe_stages", + has_act_order=has_act_order, + has_zp=has_zp, + group_blocks=group_blocks, + is_zp_float=False, + ) + + all_template_str_list.append(template_str) + + file_content = FILE_HEAD + "\n\n" + file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" + filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" + + with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: + f.write(file_content) + + +if __name__ == "__main__": + remove_old_kernels() + generate_new_kernels() diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h new file mode 100644 index 00000000..3d92660e --- /dev/null +++ b/csrc/moe/marlin_moe_wna16/kernel.h @@ -0,0 +1,44 @@ + +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +#include "quantization/gptq_marlin/marlin.cuh" +#include "quantization/gptq_marlin/marlin_dtypes.cuh" +#include "core/scalar_type.hpp" + +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, \ + int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \ + const int *__restrict__ g_idx, \ + const int32_t *__restrict__ sorted_token_ids_ptr, \ + const int32_t *__restrict__ expert_ids_ptr, \ + const int32_t *__restrict__ num_tokens_past_padded_ptr, \ + const float *__restrict__ topk_weights_ptr, int top_k, \ + bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ + int prob_n, int prob_k, int *locks, bool use_atomic_add, \ + bool use_fp32_reduce + +namespace MARLIN_NAMESPACE_NAME { +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? 
+ > +__global__ void Marlin(MARLIN_KERNEL_PARAMS); + +} diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h new file mode 100644 index 00000000..205b308f --- /dev/null +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -0,0 +1,1917 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +#include "quantization/gptq_marlin/marlin.cuh" +#include "quantization/gptq_marlin/marlin_dtypes.cuh" +#include "core/scalar_type.hpp" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +namespace MARLIN_NAMESPACE_NAME { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids + const int32_t* __restrict__ expert_ids_ptr, // moe expert ids + const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens + const float* __restrict__ topk_weights_ptr, // moe top weights + int top_k, // num of experts per token + bool mul_topk_weights, // mul topk weights or not + bool is_ep, // expert parallelism + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce // whether to use fp32 global reduce +) {} + +} // namespace MARLIN_NAMESPACE_NAME + +#else + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. 
+template +__device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +template +__device__ inline void mma_trans( + const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + const typename ScalarType::FragB& frag_b2, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + const uint32_t* b2 = reinterpret_cast(&frag_b2); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template +__device__ inline void ldsm(typename ScalarType::FragA& frag_a, + const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (count == 4) { + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); + } else if constexpr (count == 2) { + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" + : "=r"(a[0]), "=r"(a[1]) + : "r"(smem)); + } else if constexpr (count == 1) { + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" + : "=r"(a[0]) + : "r"(smem)); + } else { + static_assert(count == 1 || count == 2 || count == 4, "invalid count"); + } +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. 
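+// Example: instantiating the helper below with lut = (0xf0 & 0xcc) | 0xaa
+// evaluates (a & b) | c in a single LOP3, which the dequantizers use to mask
+// out packed nibbles and OR in the fp16 exponent bits in one instruction.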
+template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline typename ScalarType::FragB dequant( + int q, typename ScalarType::FragB& frag_b); + +// +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +// +template <> +__device__ inline typename ScalarType::FragB dequant( + int q, typename ScalarType::FragB& frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); + int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant(int q, + typename ScalarType::FragB& frag_b) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. 
+ + int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + q >>= 4; + int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC308C308; + + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +// +template <> +__device__ inline typename ScalarType::FragB dequant( + int q, typename ScalarType::FragB& frag_b) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant(int q, + typename ScalarType::FragB& frag_b) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); + + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. 
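+// The zero-point variant below (scale_and_sub) folds the subtraction into a
+// single HFMA: callers pre-multiply the zero point by the scale, so computing
+// b * s - (zp * s) matches (b - zp) * s while saving one half2 instruction.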
+template +__device__ inline void scale(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = + ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +template +__device__ inline void scale_and_sub( + typename ScalarType::FragB& frag_b, scalar_t s, scalar_t zp) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s2 = ScalarType::num2num2(s); + scalar_t2 zp2 = ScalarType::num2num2(zp); + frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); + frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); +} + +template +__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, + typename ScalarType::scalar_t2& frag_zp, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 zp = + ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); +} + +// Same as above, but for act_order (each K is multiplied individually) +template +__device__ inline void scale4(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s_1, + typename ScalarType::FragS& frag_s_2, + typename ScalarType::FragS& frag_s_3, + typename ScalarType::FragS& frag_s_4, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Given 2 floats multiply by 2 scales (halves) +template +__device__ inline void scale_float(float* c, + typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +// Wait until value of lock to be negative, and then add 1 +__device__ inline void wait_negative_and_add(int* lock) { + if (threadIdx.x == 0) { + int state = 0; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. 
+ asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state >= 0); + atomicAdd(lock, 1); + } + __syncthreads(); +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids + const int32_t* __restrict__ expert_ids_ptr, // moe expert ids + const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens + const float* __restrict__ topk_weights_ptr, // moe top weights + int top_k, // num of experts per token + bool mul_topk_weights, // mul topk weights or not + bool is_ep, // expert parallelism + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce // whether to use fp32 global reduce +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + using FragZP = typename ScalarType::FragZP; + + extern __shared__ int4 sh[]; + static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + + constexpr int pack_factor = 32 / w_type.size_bits(); + static_assert(thread_m_blocks == 1 || !m_block_size_8); + constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); + const int group_size = + (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups; + const int scales_expert_stride = prob_n * prob_k / group_size / 8; + const int zp_expert_stride = + is_zp_float ? 
prob_n * prob_k / group_size / 8 + : prob_n * prob_k / group_size / (pack_factor * 4); + + // parallel: num valid moe blocks + int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; + int parallel = num_tokens_past_padded / moe_block_size; + int num_valid_blocks = parallel; + if (is_ep) { + for (int i = 0; i < parallel; i++) { + if (expert_ids_ptr[i] == -1) num_valid_blocks--; + } + } + int num_invalid_blocks = parallel - num_valid_blocks; + parallel = num_valid_blocks; + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + int par_id = 0; + int block_id = -1; + int64_t expert_id = 0; // use int64 to avoid computation result overflow + int old_expert_id = 0; + int64_t B_expert_off = 0; + + int4* sh_block_sorted_ids_int4 = sh; + int32_t* sh_block_sorted_ids = + reinterpret_cast(sh_block_sorted_ids_int4); + int4* sh_block_topk_weights_int4 = + sh_block_sorted_ids_int4 + moe_block_size / 4; + scalar_t2* sh_block_topk_weights = + reinterpret_cast(sh_block_topk_weights_int4); + int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 4; + + int32_t block_num_valid_tokens = 0; + int32_t locks_off = 0; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + slice_col = slice_col_par % n_tiles; + par_id = slice_col_par / n_tiles; + } + if (parallel * n_tiles >= gridDim.x) { + // when parallel * n_tiles >= sms + // then there are at most $sms$ conflict tile blocks + locks_off = blockIdx.x; + } else { + locks_off = (iters * blockIdx.x) / k_tiles - 1; + } + + // read moe block data given block_id + // block_sorted_ids / block_num_valid_tokens / block_topk_weights + auto read_moe_block_data = [&](int block_id) { + block_num_valid_tokens = moe_block_size; + #pragma unroll + for (int i = 0; i < moe_block_size / 4; i++) { + int4 sorted_token_ids_int4 = reinterpret_cast( + sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i]; + int* sorted_token_ids = reinterpret_cast(&sorted_token_ids_int4); + #pragma unroll + for (int j = 0; j < 4; j++) { + if (sorted_token_ids[j] >= prob_m * top_k) { + block_num_valid_tokens = i * 4 + j; + break; + } + } + if (block_num_valid_tokens != moe_block_size) break; + } + + __syncthreads(); + int tid4 = threadIdx.x / 4; + if (threadIdx.x % 4 == 0 && threadIdx.x < block_num_valid_tokens) { + sh_block_sorted_ids_int4[tid4] = reinterpret_cast( + sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4]; + + if (mul_topk_weights) { + #pragma unroll + for (int i = 0; i < 4; i++) { + sh_block_topk_weights[tid4 * 4 + i] = + Dtype::num2num2(Dtype::float2num( + topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])); + 
} + } + } + __syncthreads(); + }; + + // when move to next moe block, find the next block_id and expert_id + // and then read moe block data + auto update_next_moe_block_data = [&]() { + if (par_id >= parallel) return; + + old_expert_id = expert_id; + if (num_invalid_blocks > 0) { + int skip_count = block_id == -1 ? par_id : 0; + block_id++; + for (int i = block_id; i < num_tokens_past_padded / moe_block_size; i++) { + expert_id = expert_ids_ptr[i]; + if (expert_id != -1) { + if (skip_count == 0) { + block_id = i; + break; + }; + skip_count--; + }; + } + } else { + block_id = par_id; + expert_id = expert_ids_ptr[block_id]; + } + + B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4); + scales_ptr += (expert_id - old_expert_id) * scales_expert_stride; + if constexpr (has_zp) { + zp_ptr += (expert_id - old_expert_id) * zp_expert_stride; + } + if constexpr (has_act_order) { + g_idx += (expert_id - old_expert_id) * prob_k; + } + + read_moe_block_data(block_id); + }; + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&](bool first_init = false) { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (parallel * n_tiles >= gridDim.x) { + if (slice_count > 1 && slice_idx == slice_count - 1) { + locks_off++; + } + } else { + locks_off++; + } + + if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) { + constexpr int threads_per_m = 16 * thread_n_blocks / 8; + int m_per_thread = + div_ceil(block_num_valid_tokens, threads / threads_per_m); + for (int i = 0; i < m_per_thread; i++) { + int row = threads / threads_per_m * i + threadIdx.x / threads_per_m; + if (row < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[row]; + int col = slice_col * 16 * thread_n_blocks / 8 + + threadIdx.x % threads_per_m; + C[sorted_row * prob_n / 8 + col] = {0, 0, 0, 0}; + } + } + // After write zero to output, write a negative value to lock. + // Every SM that processes the same slice would wait for + // the negative value, and then atomicAdd 1 to it. + // After all SMs are processed, the lock value would back to 0 again. 
+ __syncthreads(); + if (threadIdx.x == 0) locks[locks_off] = 1 - slice_count; + } + + if (slice_col == n_tiles) { + slice_col = 0; + par_id++; + update_next_moe_block_data(); + } + }; + + update_next_moe_block_data(); + init_slice(true); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = is_zp_float + ? 16 * thread_n_blocks / 8 + : ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 
2 : 1)); + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } + } + int zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp)) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 8; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + if constexpr (is_zp_float) { + if constexpr (group_blocks != -1) { + zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + } + } else { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. 
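+  // Roughly, row r of an A tile lands at column (col ^ r) of its
+  // shared-memory row, so the 16-byte accesses of 8 consecutive threads do
+  // not collide on a bank regardless of which fragment row is being read.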
+ int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh_new; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); + int4* sh_red = sh_b; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. 
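+  // MoE-specific detail: A rows are gathered through sh_block_sorted_ids
+  // (divided by top_k to recover the original token row) and every B access
+  // is offset by B_expert_off, so each block streams the weights of exactly
+  // one expert while reusing the dense Marlin pipeline.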
+ int a_remaining_load_count_in_slice = stages; + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + if (prob_k > thread_k_blocks * 16 * stages || slice_col == 0 || + a_remaining_load_count_in_slice > 0) { + a_remaining_load_count_in_slice--; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; + int row = a_idx / a_gl_stride; + int64_t sorted_row = 0; + if (!m_block_size_8 || row < 8) + sorted_row = sh_block_sorted_ids[row] / top_k; + int64_t true_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], + row < block_num_valid_tokens); + } + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], + B_ptr[i] + j + B_expert_off); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], + &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_col_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + auto fetch_col_scale_to_shared = [&]() { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). 
+ cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm( + frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } else if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } 
+ + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp && !is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + #pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + + else if constexpr (has_zp && is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + int cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + reinterpret_cast(&frag_zpf[k % 2])[0] = + sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; + } + } + } + }; + + // Execute the actual tensor core matmul of a sub-tile. 
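+  // Each inner iteration dequantizes eight quantized B values into two
+  // fp16x2 fragments, applies scales / zero-points in registers, and then
+  // issues an mma (or mma_trans for the 8-row m_block_size_8 case) per
+  // m block, so dequantization overlaps with the tensor-core work.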
+ bool is_first_matmul_in_slice = true; + auto matmul = [&](int k) { + int k2 = k % 2; + const bool is_new_zp = + ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || + (group_blocks == -1 && is_first_matmul_in_slice); + if constexpr (has_zp && !is_zp_float) { + if (is_new_zp) { + if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = frag_qzp[k2][1]; + } + + dequant(zp_quant_0, frag_zp_0); + dequant(zp_quant_1, frag_zp_1); + + frag_zp[0] = frag_zp_0[0]; + frag_zp[1] = frag_zp_0[1]; + frag_zp[2] = frag_zp_1[0]; + frag_zp[3] = frag_zp_1[1]; + } + } + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + int b_quant_0, b_quant_1; + + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + dequant(b_quant_0, frag_b0); + dequant(b_quant_1, frag_b1); + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + static_assert(group_blocks != -1); + scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k][2][j], act_frag_s[k2][3][j], 1); + + } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 s2 = Dtype::nums2num2( + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); + if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (has_zp && !is_zp_float && group_blocks != -1) { + if (is_new_zp) + frag_zp[j] = __hmul2(frag_zp[j], + *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k % 2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k % 2][j][0].y, frag_zp[j].y); + } else if constexpr (has_zp && is_zp_float && group_blocks != -1) { + if (is_new_zp) + frag_zpf[k2][j] = __hmul2( + frag_zpf[k2][j], *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k2][j].x, frag_zpf[k2][j].x); + scale_and_sub(frag_b1, frag_s[k2][j].y, frag_zpf[k2][j].y); + } else if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + if constexpr (m_block_size_8) { + mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + } else { + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. 
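+  // Partial sums are exchanged through sh_red, which aliases the B staging
+  // buffer (sh_red = sh_b above), so the tree reduction over warps costs no
+  // additional shared memory.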
+ auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = reinterpret_cast( + &sh_red[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh_red[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) { + float* c_rd = + reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce_fp16 = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + if (!is_th_active) { + return; + } + + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr; + if constexpr (m_block_size_8) { + c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + + (threadIdx.x % 32) / 8; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } else { + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + if (!first) { + + #pragma unroll + for (int i = 0; i < (m_block_size_8 ? 
2 : thread_m_blocks * 4); i++) { + int c_idx; + if constexpr (m_block_size_8) + c_idx = c_gl_wr + i * c_gl_stride + + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i; + else + c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + if (c_idx / c_gl_stride < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; + int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + sh_red[c_sh_wr + c_sh_wr_delta * i] = C[true_idx]; + } + } + } + + #pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + if (!first) { + int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + } + + int c_idx; + if constexpr (m_block_size_8) + c_idx = c_gl_wr + i * c_gl_stride + + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i; + else + c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + if (c_idx / c_gl_stride < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; + int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + C[true_idx] = c; + } + } + } + }; + + // Globally reduce over threadblocks that compute the same column block. + // We use a tmp C buffer to reduce in full fp32 precision. + auto global_reduce_fp32 = [&](bool first = false, bool last = false) { + constexpr int tb_m = thread_m_blocks * 16; + constexpr int tb_n = thread_n_blocks * 16; + + constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; + + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + + constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int th_size = num_floats * sizeof(float) / 16; + + int c_cur_offset = locks_off * c_size; + + if (!is_th_active) { + return; + } + + if (!first) { + float* frag_c_ptr = reinterpret_cast(&frag_c); + #pragma unroll + for (int k = 0; k < th_size; k++) { + if constexpr (m_block_size_8) { + if (k % 2) continue; + } else { + if (k / 8 * 16 + (threadIdx.x % 32) / 4 >= block_num_valid_tokens) + continue; + } + + sh_red[threadIdx.x] = + C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; + + float* sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); + #pragma unroll + for (int f = 0; f < 4; f++) { + frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; + } + } + } + + if (!last) { + int4* frag_c_ptr = reinterpret_cast(&frag_c); + #pragma unroll + for (int k = 0; k < th_size; k++) { + if constexpr (m_block_size_8) { + if (k % 2) continue; + } else { + if (k / 8 * 16 + (threadIdx.x % 32) / 4 >= block_num_valid_tokens) + continue; + } + + C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. 
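[editor note] The two global-reduce variants above trade memory traffic for precision: the fp16 path reduces directly in the output buffer (good L2 reuse, but every partial result is rounded to fp16 when stored), while the use_fp32_reduce path keeps partial sums in the float32 C_tmp buffer. A small NumPy sketch of the rounding drift the fp32 path avoids, illustrative only:

import numpy as np

partials = np.full(64, 0.3333, dtype=np.float32)   # partial tile results from 64 k-slices
acc16 = np.float16(0.0)
for p in partials:                                  # fp16-style path: round after every store
    acc16 = np.float16(acc16 + np.float16(p))
acc32 = partials.sum(dtype=np.float32)              # fp32 path: full-precision accumulation
print(float(acc16), float(acc32))                   # the fp16 accumulator drifts from the fp32 sum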
+ auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr; + if constexpr (m_block_size_8) { + c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + + (threadIdx.x % 32) / 4; + c_sh_wr += 64 * (threadIdx.x / 32); + } else { + c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + } + + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = + Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4 && !has_zp) { + res = __hmul2(res, s[0]); + } + + if constexpr (m_block_size_8) { + ((scalar_t*)sh_red)[idx] = res.x; + ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; + } else { + ((scalar_t2*)sh_red)[idx] = res; + } + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + if constexpr (m_block_size_8) { + int wr = c_sh_wr + 16 * j; + write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], + frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], + frag_s[j / 2][2 * (j % 2) + 1]); + } else { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + int row = c_gl_wr / c_gl_stride; + if (row < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[row]; + int64_t true_idx = sorted_row * c_gl_stride + c_gl_wr % c_gl_stride; + scalar_t2 topk_weight_score; + if (mul_topk_weights) topk_weight_score = sh_block_topk_weights[row]; + if (use_atomic_add && slice_count > 1 || mul_topk_weights) { + scalar_t2* C_half2 = reinterpret_cast(&C[true_idx]); + scalar_t2* sh_red_half2 = + reinterpret_cast(&sh_red[c_sh_rd]); + #pragma unroll + for (int a = 0; a < 4; a++) { + scalar_t2 res = sh_red_half2[a]; + if (mul_topk_weights) { + res = __hmul2(res, topk_weight_score); + } + + if (use_atomic_add && slice_count > 1) { + atomicAdd(&C_half2[a], res); + } else { + C_half2[a] = res; + }; + } + } else { + C[true_idx] = sh_red[c_sh_rd]; + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + __syncthreads(); + }; + + // Start global fetch and register load pipelines. 
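[editor note] Before the pipeline start-up code below: write_result scatters each tile row back through the sorted token ids, optionally scales it by the row's top-k routing weight, and accumulates with atomicAdd when several slices target the same output rows. A condensed Python sketch of that write-out rule (illustrative; names are made up, and the real kernel additionally conditions atomicAdd on slice_count > 1):

def write_back(C, tile_out, block_sorted_ids, topk_weights,
               mul_topk_weights, use_atomic_add):
    # C has size_m * top_k rows; each tile row lands at its sorted token id
    for r, value in enumerate(tile_out):
        token = block_sorted_ids[r]
        if mul_topk_weights:
            value = value * topk_weights[token]   # per routed (token, expert) weight
        if use_atomic_add:
            C[token] += value                     # atomicAdd on half2 in the real kernel
        else:
            C[token] = value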
+ auto start_pipes = [&]() { + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], + g_idx[last_g_idx]); + } + + if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + if (i == 0) { + fetch_col_zp_to_shared(); + fetch_col_scale_to_shared(); + } + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + a_remaining_load_count_in_slice = 0; + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
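[editor note] Before the slice epilogue below: the main loop above is a classic software pipeline. While one stage is consumed from registers, the stage that lies stages - 1 tiles ahead is fetched to shared memory, so global loads overlap with the tensor-core matmuls. A schematic of that schedule (illustrative only; it ignores the double-buffered register loads and the k-slicing across warps):

def pipelined_loop(num_tiles, stages, fetch_to_shared, fetch_to_registers, compute):
    # prologue: fill stages - 1 shared-memory buffers, mirroring start_pipes()
    for i in range(min(stages - 1, num_tiles)):
        fetch_to_shared(i)
    for k in range(num_tiles):
        fetch_to_registers(k)                   # registers for the tile computed now
        if k + stages - 1 < num_tiles:
            fetch_to_shared(k + stages - 1)     # keep shared memory stages - 1 tiles ahead
        compute(k)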
+ if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + if constexpr (m_block_size_8) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2* frag_s_half2 = reinterpret_cast(frag_s); + #pragma unroll + for (int i = 0; i < 8; i++) { + frag_s_half2[i] = Dtype::num2num2( + reinterpret_cast(&frag_s_half2[i])[idx]); + } + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8 && !has_zp) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float( + reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); + + if constexpr (!m_block_size_8) { + scale_float( + reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float( + reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + } + + if (slice_count > 1 && !use_atomic_add) { + // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[locks_off], slice_idx); + if (use_fp32_reduce) { + global_reduce_fp32(slice_idx == 0, last); + } else { + global_reduce_fp16(slice_idx == 0, last); + } + barrier_release(&locks[locks_off], last); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) + wait_negative_and_add(&locks[locks_off]); + if (last || use_atomic_add) + // only the last block in a slice actually writes the result + write_result(); + if (slice_row) a_remaining_load_count_in_slice = stages; + slice_row = 0; + slice_col_par++; + slice_col++; + is_first_matmul_in_slice = true; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } + } + } +} + +} // namespace MARLIN_NAMESPACE_NAME + +#endif diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu new file mode 100644 index 
00000000..a16e955a --- /dev/null +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -0,0 +1,927 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +#include "kernel.h" +#include "core/registration.h" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +namespace MARLIN_NAMESPACE_NAME { + +__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; + +using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template +__global__ void permute_cols_kernel( + int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, + const int32_t* __restrict__ sorted_token_ids_ptr, + const int32_t* __restrict__ expert_ids_ptr, + const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m, + int size_k, int top_k) {}; + +} // namespace marlin + +torch::Tensor moe_wna16_marlin_gemm( + torch::Tensor& a, std::optional const& c_or_none, + torch::Tensor& b_q_weight, torch::Tensor& b_scales, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, torch::Tensor& workspace, + torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids, + torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights, + int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep, + vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, + bool is_zp_float) { + TORCH_CHECK_NOT_IMPLEMENTED(false, + "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. 
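[editor note] The CUDA implementation follows; as a reference, the same column permutation in NumPy terms (illustrative sketch, not part of the patch). perm is the current expert's group-sorted column order, and each routed row reads from its source token row:

import numpy as np

def permute_a(a, perm, sorted_token_ids, top_k):
    m, k = a.shape                                  # a: [M, K] activations
    out = np.zeros((m * top_k, k), dtype=a.dtype)
    for token_id in sorted_token_ids:
        if token_id >= m * top_k:                   # padding entries are skipped
            continue
        out[token_id] = a[token_id // top_k, perm]  # gather K columns in sorted-group order
    return out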
+template +__global__ void permute_cols_kernel( + int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, + const int32_t* __restrict__ sorted_token_ids_ptr, + const int32_t* __restrict__ expert_ids_ptr, + const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m, + int size_k, int top_k) { + int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; + int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size); + int32_t block_sorted_ids[moe_block_size]; + int block_num_valid_tokens = 0; + int64_t old_expert_id = 0; + int64_t expert_id = 0; + int row_stride = size_k * sizeof(half) / 16; + + auto read_moe_block_data = [&](int block_id) { + block_num_valid_tokens = moe_block_size; + int4* tmp_block_sorted_ids = reinterpret_cast(block_sorted_ids); + for (int i = 0; i < moe_block_size / 4; i++) { + tmp_block_sorted_ids[i] = + ((int4*)sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i]; + } + for (int i = 0; i < moe_block_size; i++) { + if (block_sorted_ids[i] >= size_m * top_k) { + block_num_valid_tokens = i; + break; + }; + } + }; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int in_offset = (row / top_k) * row_stride; + int out_offset = row * row_stride; + + half const* a_row_half = + reinterpret_cast(a_int4_ptr + in_offset); + half* out_half = reinterpret_cast(out_int4_ptr + out_offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) { + old_expert_id = expert_id; + int tmp_expert_id = expert_ids_ptr[index]; + if (tmp_expert_id == -1) continue; + expert_id = tmp_expert_id; + perm_int_ptr += (expert_id - old_expert_id) * size_k; + read_moe_block_data(index); + + for (int i = 0; i < block_num_valid_tokens; i++) + permute_row(block_sorted_ids[i]); + } +} + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}}; + +typedef struct { + int blocks_per_sm; + thread_config_t tb_cfg; +} exec_config_t; + +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = + tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + + } else { + int tb_scales = tb_groups * 
tb_n * 2; + + return tb_scales * pipe_stages; + } +} + +int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int has_zp, int is_zp_float) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + int tb_m = thread_m_blocks * 16; + + // shm size for block_sorted_ids/block_topk_weights + // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) + int sh_block_meta_size = tb_m * 4 * 2; + int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; + int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; + int sh_s_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0; + int sh_zp_size = 0; + if (has_zp) { + if (is_zp_float) + sh_zp_size = sh_s_size; + else if (num_bits == 4) + sh_zp_size = sh_s_size / 4; + else if (num_bits == 8) + sh_zp_size = sh_s_size / 2; + } + + int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size + + sh_g_idx_size + sh_block_meta_size; + + return total_size; +} + +bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int has_zp, int is_zp_float, int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Check that pipeline fits into cache + int cache_size = get_kernel_cache_size( + th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, has_zp, is_zp_float); + return cache_size <= max_shared_mem; +} + + #define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ + NUM_THREADS, IS_ZP_FLOAT) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + m_block_size_8 == M_BLOCK_SIZE_8 && \ + has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ + is_zp_float == IS_ZP_FLOAT) { \ + kernel = Marlin; \ + } + + #define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \ + false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ + NUM_THREADS, false) \ + \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, 
K_BLOCKS, false, false, false, 2, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ + NUM_THREADS, false) + + #define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ + NUM_THREADS, false) \ + \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ + NUM_THREADS, false) \ + \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ + NUM_THREADS, false) \ + \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ + NUM_THREADS, false) + + #define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \ + false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \ + false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \ + false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ + NUM_THREADS, false) + + #define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ + NUM_THREADS, false) \ + \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ + NUM_THREADS, false) \ + \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, 
\ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ + NUM_THREADS, false) + + // We currently have 4-bit models only with group_blocks == 4 + #define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \ + true) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ + NUM_THREADS, true) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ + NUM_THREADS, true) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ + NUM_THREADS, true) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ + NUM_THREADS, true) + +template +MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, + int thread_m_blocks, int thread_n_blocks, + int thread_k_blocks, bool m_block_size_8, + bool has_act_order, bool has_zp, + int group_blocks, int num_threads, + bool is_zp_float) { + int num_bits = q_type.size_bits(); + auto kernel = MarlinDefault; + if (false) { + } + GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256) + GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128) + + GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256) + GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128) + + GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256) + GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128) + + GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256) + GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128) + + AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256) + AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128) + + AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256) + AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128) + + return kernel; +} + +template +exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, + int prob_n, int prob_k, int thread_m_blocks, + bool m_block_size_8, int num_bits, + int group_size, bool has_act_order, + bool is_k_full, bool has_zp, + bool is_zp_float, int max_shared_mem) { + exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; + thread_config_t* thread_configs = thread_m_blocks > 1 + ? large_batch_thread_configs + : small_batch_thread_configs; + int thread_configs_size = + thread_m_blocks > 1 + ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t) + : sizeof(small_batch_thread_configs) / sizeof(thread_config_t); + + int count = 0; + constexpr int device_max_reg_size = 255 * 1024; + for (int i = 0; i < thread_configs_size; i++) { + thread_config_t th_config = thread_configs[i]; + + if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, has_zp, + is_zp_float, max_shared_mem)) { + continue; + } + + int cache_size = get_kernel_cache_size( + th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full, has_zp, is_zp_float); + + int group_blocks = 0; + if (!has_act_order) { + group_blocks = group_size == -1 ? 
-1 : group_size / 16; + } + + auto kernel = get_marlin_kernel( + q_type, thread_m_blocks, th_config.thread_n / 16, + th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, + group_blocks, th_config.num_threads, is_zp_float); + + if (kernel == MarlinDefault) continue; + + if (thread_m_blocks > 1) { + exec_cfg = {1, th_config}; + break; + } else { + cudaFuncAttributes attr; + cudaFuncGetAttributes(&attr, kernel); + int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4; + int allow_count = min(device_max_reg_size / reg_size, + max_shared_mem / (cache_size + 1024)); + allow_count = max(min(allow_count, 4), 1); + if (allow_count > count) { + count = allow_count; + exec_cfg = {count, th_config}; + }; + } + } + + return exec_cfg; +} + +template +void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, + void* zp, void* g_idx, void* perm, void* a_tmp, + void* sorted_token_ids, void* expert_ids, + void* num_tokens_past_padded, void* topk_weights, + int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, + int prob_m, int prob_n, int prob_k, void* workspace, + vllm::ScalarType const& q_type, bool has_act_order, + bool is_k_full, bool has_zp, int num_groups, int group_size, + int dev, cudaStream_t stream, int thread_k, int thread_n, + int sms, bool use_atomic_add, bool use_fp32_reduce, + bool is_zp_float) { + int thread_m_blocks = div_ceil(moe_block_size, 16); + bool m_block_size_8 = moe_block_size == 8; + + if (has_zp) { + TORCH_CHECK( + q_type == vllm::kU4 || q_type == vllm::kU8, + "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); + } else { + TORCH_CHECK( + q_type == vllm::kU4B8 || q_type == vllm::kU8B128, + "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ", + q_type.str()); + } + + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } else { + TORCH_CHECK(group_size == 0); + group_blocks = 0; + } + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + } + + int num_bits = q_type.size_bits(); + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + int4* C_tmp_ptr = (int4*)C_tmp; + const int4* s_ptr = (const int4*)s; + const int4* zp_ptr = (const int4*)zp; + const int* g_idx_ptr = (const int*)g_idx; + const int* perm_ptr = (const int*)perm; + int4* a_tmp_ptr = (int4*)a_tmp; + const int32_t* sorted_token_ids_ptr = (const int32_t*)sorted_token_ids; + const int32_t* expert_ids_ptr = (const int32_t*)expert_ids; + const int32_t* num_tokens_past_padded_ptr = + (const int32_t*)num_tokens_past_padded; + const float* topk_weights_ptr = (const float*)topk_weights; + int* locks = (int*)workspace; + + if (has_act_order) { + // Permute A columns + auto kernel = permute_cols_kernel<8>; + if (moe_block_size == 8) { + } else if (moe_block_size == 16) + kernel = permute_cols_kernel<16>; + else if (moe_block_size == 32) + kernel = permute_cols_kernel<32>; + else if (moe_block_size == 48) + kernel = permute_cols_kernel<48>; + else if (moe_block_size == 64) + kernel = permute_cols_kernel<64>; + else + 
TORCH_CHECK(false, "unsupported moe_block_size ", moe_block_size); + + // avoid ">>>" being formatted to "> > >" + // clang-format off + kernel<<>>( + A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr, expert_ids_ptr, + num_tokens_past_padded_ptr, prob_m, prob_k, top_k); + // clang-format on + A_ptr = a_tmp_ptr; + prob_m = prob_m * top_k; + top_k = 1; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) has_act_order = false; + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + // Set thread config + exec_config_t exec_cfg; + thread_config_t thread_tfg; + if (thread_k != -1 && thread_n != -1) { + thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = exec_config_t{1, thread_tfg}; + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + } else { + // Auto config + exec_cfg = determine_exec_config( + q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8, + num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, + max_shared_mem); + thread_tfg = exec_cfg.tb_cfg; + } + + int num_threads = thread_tfg.num_threads; + thread_k = thread_tfg.thread_k; + thread_n = thread_tfg.thread_n; + int blocks = sms * exec_cfg.blocks_per_sm; + if (exec_cfg.blocks_per_sm > 1) + max_shared_mem = max_shared_mem / exec_cfg.blocks_per_sm - 1024; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n, + prob_k, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, max_shared_mem), + "Invalid thread config: thread_m_blocks = ", thread_m_blocks, + ", thread_k = ", thread_tfg.thread_k, + ", thread_n = ", thread_tfg.thread_n, + ", num_threads = ", thread_tfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, + ", max_shared_mem = ", max_shared_mem); + + auto kernel = get_marlin_kernel( + q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8, + has_act_order, has_zp, group_blocks, num_threads, is_zp_float); + + if (kernel == MarlinDefault) { + TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, + ", ", prob_k, "]", ", has_act_order = ", has_act_order, + ", num_groups = ", num_groups, ", group_size = ", group_size, + ", thread_m_blocks = ", thread_m_blocks, + ", thread_n_blocks = ", thread_n_blocks, + ", thread_k_blocks = ", thread_k_blocks, + ", num_bits = ", num_bits); + } + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + max_shared_mem); + // avoid ">>>" being formatted to "> > >" + // clang-format off + kernel<<>>( + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, + sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, + topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, + prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce); + // clang-format on +} + +} // namespace MARLIN_NAMESPACE_NAME + +torch::Tensor 
moe_wna16_marlin_gemm( + torch::Tensor& a, std::optional const& c_or_none, + torch::Tensor& b_q_weight, torch::Tensor& b_scales, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, torch::Tensor& workspace, + torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids, + torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights, + int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep, + vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, + bool is_zp_float) { + vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); + int pack_factor = 32 / b_q_type.size_bits(); + + if (moe_block_size != 8) { + TORCH_CHECK(moe_block_size % 16 == 0, + "unsupported moe_block_size=", moe_block_size); + TORCH_CHECK(moe_block_size >= 16 && moe_block_size <= 64, + "unsupported moe_block_size=", moe_block_size); + } + + // Verify A + TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), + ", size_m = ", size_m); + TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), + ", size_k = ", size_k); + + // Verify B + TORCH_CHECK( + size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size); + TORCH_CHECK((size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(1), + "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1), + ", size_k = ", size_k, + ", tile_size = ", MARLIN_NAMESPACE_NAME::tile_size); + TORCH_CHECK( + b_q_weight.size(2) % MARLIN_NAMESPACE_NAME::tile_size == 0, + "b_q_weight.size(2) = ", b_q_weight.size(2), + " is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size); + int actual_size_n = + (b_q_weight.size(2) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor; + TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, + ", actual_size_n = ", actual_size_n); + + // Verify device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel + int sms = -1; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c; + if (c_or_none.has_value()) { + c = c_or_none.value(); + TORCH_CHECK(c.device().is_cuda(), "c is not on GPU"); + TORCH_CHECK(c.is_contiguous(), "c is not contiguous"); + TORCH_CHECK(c.size(0) == size_m * top_k, + "Shape mismatch: c.size(0) = ", c.size(0), + ", size_m * topk = ", size_m * top_k); + TORCH_CHECK(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1), + ", size_n = ", size_n); + } else { + c = torch::empty({size_m * top_k, size_n}, options); + } + + // Alloc C tmp buffer that is going to be used for the global 
reduce + torch::Tensor c_tmp; + auto options_fp32 = + torch::TensorOptions().dtype(at::kFloat).device(a.device()); + if (use_fp32_reduce && !use_atomic_add) { + // max num of threadblocks is sms * 4 + long max_c_tmp_size = min( + (long)size_n * sorted_token_ids.size(0), + (long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n); + if (moe_block_size == 8) max_c_tmp_size *= 2; + c_tmp = torch::empty({max_c_tmp_size}, options_fp32); + } else { + c_tmp = torch::empty({0}, options_fp32); + } + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + + int rank = b_scales.sizes().size(); + TORCH_CHECK(rank == 3, "b_scales rank = ", rank, " is not 3"); + TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2), + " is not size_n = ", size_n); + num_groups = b_scales.size(1); + + torch::Tensor g_idx, perm, a_tmp; + ; + if (g_idx_or_none.has_value() && perm_or_none.has_value()) { + g_idx = g_idx_or_none.value(); + perm = perm_or_none.value(); + + TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); + TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); + TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); + TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + + // Verify g_idx and perm + TORCH_CHECK((g_idx.size(-1) == 0 && perm.size(-1) == 0) || + (g_idx.size(-1) == size_k && perm.size(-1) == size_k), + "Unexpected g_idx.size(-1) = ", g_idx.size(-1), + " and perm.size(-1) = ", perm.size(-1), + ", where size_k = ", size_k); + } else { + g_idx = torch::empty({0}, options); + perm = torch::empty({0}, options); + a_tmp = torch::empty({0}, options); + } + bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0; + + if (has_act_order) { + a_tmp = torch::empty({size_m * top_k, size_k}, options); + if (is_k_full) { + TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); + TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by num_groups = ", num_groups); + group_size = size_k / num_groups; + } else { + group_size = 0; + } + + } else { + a_tmp = torch::empty({0}, options); + if (num_groups > 1) { + TORCH_CHECK( + size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by b_scales.size(1) = ", b_scales.size(1)); + group_size = size_k / num_groups; + } else { + group_size = -1; + } + } + + torch::Tensor b_zeros; + if (b_zeros_or_none.has_value()) { + b_zeros = b_zeros_or_none.value(); + TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU"); + TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous"); + } else { + b_zeros = torch::empty({0}, options); + } + bool has_zp = b_zeros.size(-1) > 0; + + if (has_zp) { + TORCH_CHECK( + b_q_type == vllm::kU4, + "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str()); + } else { + TORCH_CHECK( + b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, + "b_q_type must be uint4b8 or uint8b128 when has_zp = False. 
Got = ", + b_q_type.str()); + } + + if (has_zp && is_zp_float) { + TORCH_CHECK(a.scalar_type() == at::ScalarType::Half, + "Computation type must be float16 (half) when using float zero " + "points."); + } + + // Verify b_zeros + if (has_zp) { + int rank = b_zeros.sizes().size(); + TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3"); + if (is_zp_float) { + TORCH_CHECK(b_zeros.size(2) == size_n, + "b_zeros dim 2 = ", b_zeros.size(2), + " is not size_n = ", size_n); + TORCH_CHECK(num_groups == b_zeros.size(1), + "b_zeros dim 1 = ", b_zeros.size(1), + " is not num_groups = ", num_groups); + TORCH_CHECK(num_groups != -1, "num_groups must be != -1"); + } else { + TORCH_CHECK(b_zeros.size(1) == num_groups, + "b_zeros dim 1 = ", b_zeros.size(1), + " is not num_groups = ", num_groups); + TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor, + "b_zeros dim 2 = ", b_zeros.size(2), + " is not size_n / pack_factor = ", size_n / pack_factor); + } + } + + // Verify workspace size + TORCH_CHECK(size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0, + "size_n = ", size_n, ", is not divisible by min_thread_n = ", + MARLIN_NAMESPACE_NAME::min_thread_n); + + int max_n_tiles = size_n / MARLIN_NAMESPACE_NAME::min_thread_n; + int min_workspace_size = min( + max_n_tiles * (int)(sorted_token_ids.size(0) / moe_block_size), sms * 4); + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = ", workspace.numel(), + " is below min_workspace_size = ", min_workspace_size); + + int dev = a.get_device(); + if (a.scalar_type() == at::ScalarType::Half) { + MARLIN_NAMESPACE_NAME::marlin_mm( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + c_tmp.data_ptr(), b_scales.data_ptr(), + b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), + a_tmp.data_ptr(), sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), + topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep, + size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order, + is_k_full, has_zp, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + use_atomic_add, use_fp32_reduce, is_zp_float); + } else if (a.scalar_type() == at::ScalarType::BFloat16) { + MARLIN_NAMESPACE_NAME::marlin_mm( + a.data_ptr(), b_q_weight.data_ptr(), + c.data_ptr(), c_tmp.data_ptr(), + b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), + sorted_token_ids.data_ptr(), expert_ids.data_ptr(), + num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), + moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, + workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); + } else { + TORCH_CHECK(false, + "moe_wna16_marlin_gemm only supports bfloat16 and float16"); + } + + return c; +} + +#endif + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm); +} diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 718418e6..d0de4225 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -43,14 +43,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm); m.def( - "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " - "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! 
" - "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, " - "int b_q_type, SymInt size_m, " - "SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int " - "topk, " - "int moe_block_size, bool replicate_input, bool apply_weights)" - " -> Tensor"); + "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," + "Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none," + "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," + "Tensor sorted_token_ids," + "Tensor! expert_ids, Tensor! num_tokens_past_padded," + "Tensor! topk_weights, int moe_block_size, int top_k, " + "bool mul_topk_weights, bool is_ep, int b_q_type_id," + "int size_m, int size_n, int size_k," + "bool is_full_k, bool use_atomic_add," + "bool use_fp32_reduce, bool is_zp_float) -> Tensor"); + // conditionally compiled so impl registration is in source file #endif diff --git a/csrc/quantization/gptq_marlin/marlin.cuh b/csrc/quantization/gptq_marlin/marlin.cuh index 74ccbac5..f3b44641 100644 --- a/csrc/quantization/gptq_marlin/marlin.cuh +++ b/csrc/quantization/gptq_marlin/marlin.cuh @@ -9,7 +9,11 @@ #include #include -namespace marlin { +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin +#endif + +namespace MARLIN_NAMESPACE_NAME { // Marlin params @@ -23,6 +27,7 @@ static constexpr int pipe_stages = static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; +static constexpr int max_thread_n = 256; static constexpr int tile_size = 16; static constexpr int max_par = 16; @@ -84,4 +89,4 @@ __device__ inline void cp_async_wait() { #endif -} // namespace marlin +} // namespace MARLIN_NAMESPACE_NAME diff --git a/csrc/quantization/gptq_marlin/marlin_dtypes.cuh b/csrc/quantization/gptq_marlin/marlin_dtypes.cuh index be06c09b..cc160548 100644 --- a/csrc/quantization/gptq_marlin/marlin_dtypes.cuh +++ b/csrc/quantization/gptq_marlin/marlin_dtypes.cuh @@ -5,7 +5,11 @@ #include #include -namespace marlin { +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin +#endif + +namespace MARLIN_NAMESPACE_NAME { template class ScalarType {}; @@ -54,7 +58,7 @@ class ScalarType { using FragS = Vec; using FragZP = Vec; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); } @@ -74,6 +78,6 @@ class ScalarType { #endif }; -} // namespace marlin +} // namespace MARLIN_NAMESPACE_NAME #endif diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 3f4dd3cf..425f3698 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -11,16 +11,14 @@ from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock import vllm.model_executor.layers.fused_moe # noqa -from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, - torch_moe, torch_moe_single) -from vllm import _custom_ops as ops +from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe, + torch_moe_single) from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, moe_align_block_size) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - marlin_quantize) + awq_marlin_quantize, marlin_quantize) from 
vllm.model_executor.layers.quantization.utils.quant_utils import ( quantize_weights) from vllm.model_executor.models.mixtral import MixtralMoE @@ -287,14 +285,17 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, atol=mixtral_moe_tol[dtype]) -@pytest.mark.parametrize("m", [1, 33, 64, 222]) -@pytest.mark.parametrize("n", [128, 2048]) -@pytest.mark.parametrize("k", [128, 1024]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("m", [1, 33, 123]) +@pytest.mark.parametrize("n", [128, 1024]) +@pytest.mark.parametrize("k", [256, 2048]) +@pytest.mark.parametrize("e", [4, 12]) +@pytest.mark.parametrize("topk", [2, 3]) +@pytest.mark.parametrize("ep_size", [1, 4]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("group_size", [-1, 32, 128]) @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("has_zp", [True, False]) @pytest.mark.parametrize("is_k_full", [True, False]) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_fused_marlin_moe( @@ -303,9 +304,12 @@ def test_fused_marlin_moe( k: int, e: int, topk: int, + ep_size: int, + dtype: torch.dtype, group_size: int, act_order: bool, num_bits: int, + has_zp: bool, is_k_full: bool, ): current_platform.seed_everything(7) @@ -316,75 +320,110 @@ def test_fused_marlin_moe( return if group_size in (k, n): return + if has_zp: + return else: if not is_k_full: return - quant_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) - dtype = torch.float16 + if has_zp: + # we don't build kernel for int8 with zero + if num_bits == 8: + return + quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 + else: + quant_type = scalar_types.uint4b8 \ + if num_bits == 4 else scalar_types.uint8b128 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + if ep_size > 1: + local_e = e // ep_size + e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e] + e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) + e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) + w1 = w1[e_ids] + w2 = w2[e_ids] + else: + e_map = None + w_ref1_l = [] qweight1_l = [] scales1_l = [] + zeros1_l = [] g_idx1_l = [] sort_indices1_l = [] for i in range(w1.shape[0]): - test_perm = torch.randperm(k) - w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( - w1[i].transpose(1, 0), quant_type, group_size, act_order, - test_perm) - w_ref1_l.append(w_ref1) - qweight1_l.append(qweight1) - scales1_l.append(scales1) - g_idx1_l.append(g_idx1) - sort_indices1_l.append(sort_indices1) + if has_zp: + w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize( + w1[i].transpose(1, 0), quant_type, group_size) + + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + zeros1_l.append(zeros1) + else: + test_perm = torch.randperm(k) + quant_res = marlin_quantize(w1[i].transpose(1, 0), quant_type, + group_size, act_order, test_perm) + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res + + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + g_idx1_l.append(g_idx1) + sort_indices1_l.append(sort_indices1) w_ref1 = stack_and_dev(w_ref1_l) qweight1 = stack_and_dev(qweight1_l).contiguous() scales1 
= stack_and_dev(scales1_l) - g_idx1 = stack_and_dev(g_idx1_l) - sort_indices1 = stack_and_dev(sort_indices1_l) + g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None + zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None + sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None w_ref2_l = [] qweight2_l = [] scales2_l = [] + zeros2_l = [] g_idx2_l = [] sort_indices2_l = [] for i in range(w2.shape[0]): - test_perm = torch.randperm(n) - w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( - w2[i].transpose(1, 0), quant_type, group_size, act_order, - test_perm) - w_ref2_l.append(w_ref2) - qweight2_l.append(qweight2) - scales2_l.append(scales2) - g_idx2_l.append(g_idx2) - sort_indices2_l.append(sort_indices2) + if has_zp: + w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize( + w2[i].transpose(1, 0), quant_type, group_size) + + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + zeros2_l.append(zeros2) + else: + test_perm = torch.randperm(n) + quant_res = marlin_quantize(w2[i].transpose(1, 0), quant_type, + group_size, act_order, test_perm) + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res + + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + g_idx2_l.append(g_idx2) + sort_indices2_l.append(sort_indices2) w_ref2 = stack_and_dev(w_ref2_l) qweight2 = stack_and_dev(qweight2_l).contiguous() scales2 = stack_and_dev(scales2_l) - g_idx2 = stack_and_dev(g_idx2_l) - sort_indices2 = stack_and_dev(sort_indices2_l) + g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None + zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None + sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids = fused_topk(a, score, topk, False) - triton_output = fused_moe( - a, - w_ref1.transpose(1, 2).contiguous(), - w_ref2.transpose(1, 2).contiguous(), - score, - topk, - renormalize=False, - ) + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) + marlin_output = torch.ops.vllm.fused_marlin_moe( a, qweight1, @@ -394,111 +433,91 @@ def test_fused_marlin_moe( score, topk_weights, topk_ids, + global_num_experts=e, + expert_map=e_map, g_idx1=g_idx1, g_idx2=g_idx2, sort_indices1=sort_indices1, sort_indices2=sort_indices2, + w1_zeros=zeros1, + w2_zeros=zeros2, num_bits=num_bits, - is_k_full=is_k_full, - ) + is_k_full=is_k_full) - assert compute_max_diff(marlin_output, triton_output) < 4e-2 - - if ops.supports_moe_ops: - token_expert_indicies = torch.empty(m, - topk, - dtype=torch.int32, - device=a.device) - - opcheck(torch.ops._moe_C.topk_softmax, ( - topk_weights, - topk_ids, - token_expert_indicies, - score.float(), - )) - - block_size_m = 4 - - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, - e) - - max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda", - requires_grad=False) - - zp = torch.empty((0, 0), - dtype=dtype, - device="cuda", - requires_grad=False) - opcheck(torch.ops._moe_C.marlin_gemm_moe, - (a, qweight1, sorted_token_ids, topk_weights, topk_ids, - scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id, - m, 2 * n, k, True, e, topk, block_size_m, True, False)) + torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0) @pytest.mark.skip("This test is here for the sake of debugging, " "don't run it in automated tests.") 
-@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) -@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) -@pytest.mark.parametrize("k", [128, 1024, 512]) -@pytest.mark.parametrize("e", [8, 64]) -@pytest.mark.parametrize("topk", [2, 6]) -@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.parametrize("m", [1, 33, 123]) +@pytest.mark.parametrize("n", [128, 1024]) +@pytest.mark.parametrize("k", [256, 2048]) +@pytest.mark.parametrize("e", [4, 12]) +@pytest.mark.parametrize("topk", [2, 3]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("group_size", [-1, 32, 128]) @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("has_zp", [True, False]) @pytest.mark.parametrize("is_k_full", [True, False]) -@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") -def test_single_marlin_moe_multiply( - m: int, - n: int, - k: int, - e: int, - topk: int, - group_size: int, - act_order: bool, - num_bits: int, - is_k_full: bool, -): - +def test_single_marlin_moe_multiply(m: int, n: int, k: int, e: int, topk: int, + dtype: torch.dtype, group_size: int, + act_order: bool, num_bits: int, + has_zp: bool, is_k_full: bool): # Filter act_order if act_order: if group_size == -1: return - if group_size == k: + if group_size in (k, n): + return + if has_zp: return else: if not is_k_full: return - quant_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) - dtype = torch.float16 + if has_zp: + quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 + else: + quant_type = scalar_types.uint4b8 \ + if num_bits == 4 else scalar_types.uint8b128 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 w_ref_l = [] - qweights_l = [] + qweight_l = [] scales_l = [] + zeros_l = [] g_idx_l = [] sort_indices_l = [] for i in range(w.shape[0]): - test_perm = torch.randperm(k) - w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( - w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm) - w_ref_l.append(w_ref) - qweights_l.append(qweight) - scales_l.append(scales) - g_idx_l.append(g_idx) - sort_indices_l.append(sort_indices) + if has_zp: + w_ref, qweight, scales, zeros = awq_marlin_quantize( + w[i].transpose(1, 0), quant_type, group_size) + + w_ref_l.append(w_ref.T) + qweight_l.append(qweight) + scales_l.append(scales) + zeros_l.append(zeros) + else: + test_perm = torch.randperm(k) + w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( + w[i].transpose(1, 0), quant_type, group_size, act_order, + test_perm) + + w_ref_l.append(w_ref.T) + qweight_l.append(qweight) + scales_l.append(scales) + g_idx_l.append(g_idx) + sort_indices_l.append(sort_indices) w_ref = stack_and_dev(w_ref_l) - qweight = stack_and_dev(qweights_l).contiguous() + qweight = stack_and_dev(qweight_l).contiguous() scales = stack_and_dev(scales_l) - g_idx = stack_and_dev(g_idx_l) - sort_indices = stack_and_dev(sort_indices_l) + g_idx = stack_and_dev(g_idx_l) if g_idx_l else None + zeros = stack_and_dev(zeros_l) if zeros_l else None + sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None score = torch.randn((m, e), device="cuda", dtype=dtype) marlin_output = torch.ops.vllm.single_marlin_moe( @@ -510,13 +529,14 @@ def test_single_marlin_moe_multiply( renormalize=False, g_idx=g_idx, sort_indices=sort_indices, + w_zeros=zeros, num_bits=num_bits, is_k_full=is_k_full, ) - 
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) + torch_output = torch_moe_single(a, w_ref, score, topk) - assert compute_max_diff(marlin_output, torch_output) < 1e-2 + torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0) def test_moe_align_block_size_opcheck(): diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 7a4c93ad..bd930bb9 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1245,6 +1245,29 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, token_expert_indicies, gating_output) +def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], + b_qweight: torch.Tensor, b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_past_padded: torch.Tensor, + topk_weights: torch.Tensor, moe_block_size: int, + top_k: int, mul_topk_weights: bool, is_ep: bool, + b_q_type: ScalarType, size_m: int, size_n: int, + size_k: int, is_k_full: bool, use_atomic_add: bool, + use_fp32_reduce: bool, + is_zp_float: bool) -> torch.Tensor: + return torch.ops._moe_C.moe_wna16_marlin_gemm( + input, output, b_qweight, b_scales, b_qzeros, g_idx, perm, workspace, + sorted_token_ids, expert_ids, num_tokens_past_padded, topk_weights, + moe_block_size, top_k, mul_topk_weights, is_ep, b_q_type.id, size_m, + size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce, + is_zp_float) + + if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): @register_fake("_moe_C::marlin_gemm_moe") @@ -1263,6 +1286,29 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): dtype=a.dtype, device=a.device) + @register_fake("_moe_C::moe_wna16_marlin_gemm") + def moe_wna16_marlin_gemm_fake(input: torch.Tensor, + output: Optional[torch.Tensor], + b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_past_padded: torch.Tensor, + topk_weights: torch.Tensor, + moe_block_size: int, top_k: int, + mul_topk_weights: bool, is_ep: bool, + b_q_type: ScalarType, size_m: int, + size_n: int, size_k: int, is_k_full: bool, + use_atomic_add: bool, use_fp32_reduce: bool, + is_zp_float: bool) -> torch.Tensor: + return torch.empty((size_m * top_k, size_n), + dtype=input.dtype, + device=input.device) + def reshape_and_cache( key: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index ee158d7e..62614a59 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -5,17 +5,16 @@ from typing import Optional import torch +import vllm._custom_ops as ops from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size, try_get_optimal_moe_config) -from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import direct_register_custom_op def get_scalar_type(num_bits: int, has_zp: bool): if has_zp: - assert num_bits == 4 - return scalar_types.uint4 + return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 else: return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 @@ -27,9 +26,12 @@ def single_marlin_moe( 
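For reference, the get_scalar_type() helper in fused_marlin_moe.py above maps (num_bits, has_zp) onto the Marlin weight formats; the table below only restates that mapping using the types already imported from vllm.scalar_type (the dictionary name is illustrative, not part of the patch):

from vllm.scalar_type import scalar_types

# (num_bits, has_zp) -> quantized weight format handed to the Marlin MoE
# kernel. With a zero point (AWQ-style checkpoints) the plain unsigned
# types are used; without one (GPTQ-style) the bias-shifted uint4b8 /
# uint8b128 variants are used instead.
MARLIN_MOE_WEIGHT_TYPES = {
    (4, True): scalar_types.uint4,
    (8, True): scalar_types.uint8,
    (4, False): scalar_types.uint4b8,
    (8, False): scalar_types.uint8b128,
}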
gating_output: torch.Tensor, topk: int, renormalize: bool, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -62,7 +64,7 @@ def single_marlin_moe( assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w.is_contiguous(), "Expert weights must be contiguous" - assert hidden_states.dtype == torch.float16 + assert hidden_states.dtype in [torch.float16, torch.bfloat16] assert num_bits in [4, 8] M, K = hidden_states.shape @@ -83,39 +85,54 @@ def single_marlin_moe( block_size_m = config['BLOCK_SIZE_M'] - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + if global_num_experts == -1: + global_num_experts = E + sorted_token_ids, expert_ids, num_tokens_post_padded = \ + moe_align_block_size(topk_ids, block_size_m, E, expert_map) - max_workspace_size = (N // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device=hidden_states.device, - requires_grad=False) + if workspace is None: + max_workspace_size = (max(2 * N, K) // 64) * \ + (sorted_token_ids.size(0) // block_size_m) + device = hidden_states.device + sms = torch.cuda.get_device_properties(device).multi_processor_count + max_workspace_size = min(max_workspace_size, sms) + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) - has_zero_point = w_zeros is not None - if w_zeros is None: - w_zeros = torch.empty((0, 0), - dtype=hidden_states.dtype, - device=hidden_states.device, - requires_grad=False) + scalar_type = get_scalar_type(num_bits, w_zeros is not None) + intermediate_cache = torch.empty( + (M * topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) - if g_idx is None: - g_idx = torch.empty((0, 0), - dtype=torch.int32, - device=hidden_states.device, - requires_grad=False) - - if sort_indices is None: - sort_indices = torch.empty((0), - dtype=torch.int32, - device=hidden_states.device, - requires_grad=False) - - scalar_type = get_scalar_type(num_bits, has_zero_point) - - intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( - hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, - w_zeros, g_idx, sort_indices, workspace, scalar_type.id, M, N, K, - is_k_full, E, topk, block_size_m, True, False) + ops.moe_wna16_marlin_gemm(hidden_states, + intermediate_cache, + w, + scales, + w_zeros, + g_idx, + sort_indices, + workspace, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + topk_weights, + moe_block_size=block_size_m, + top_k=topk, + mul_topk_weights=False, + is_ep=expert_map is not None, + b_q_type=scalar_type, + size_m=M, + size_n=N, + size_k=K, + is_k_full=is_k_full, + use_atomic_add=False, + use_fp32_reduce=True, + is_zp_float=False) + intermediate_cache = intermediate_cache.view(-1, topk, N) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -127,9 +144,12 @@ def single_marlin_moe_fake( gating_output: torch.Tensor, topk: int, renormalize: bool, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, num_bits: int = 8, 
is_k_full: bool = True, ) -> torch.Tensor: @@ -144,24 +164,26 @@ direct_register_custom_op( ) -def fused_marlin_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - g_idx1: Optional[torch.Tensor] = None, - g_idx2: Optional[torch.Tensor] = None, - sort_indices1: Optional[torch.Tensor] = None, - sort_indices2: Optional[torch.Tensor] = None, - w1_zeros: Optional[torch.Tensor] = None, - w2_zeros: Optional[torch.Tensor] = None, - num_bits: int = 8, - is_k_full: bool = True, -) -> torch.Tensor: +def fused_marlin_moe(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + inplace: bool = False) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. @@ -196,27 +218,12 @@ def fused_marlin_moe( 1] == w1.shape[1] * 16, "Hidden size mismatch w1" assert hidden_states.shape[1] == w2.shape[2] // ( num_bits // 2), "Hidden size mismatch w2" - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype == torch.float16 + assert hidden_states.dtype in [torch.float16, torch.bfloat16] assert num_bits in [4, 8] - has_no_act_order = (g_idx1 is None and g_idx2 is None - and sort_indices1 is None and sort_indices2 is None) - has_all_act_order = (g_idx1 is not None and g_idx2 is not None - and sort_indices1 is not None - and sort_indices2 is not None) - assert has_no_act_order or has_all_act_order, ( - "g_idx and sorted_indices " - "must be all not None or must be all None") - - has_no_zp = w1_zeros is None and w2_zeros is None - has_all_zp = w1_zeros is not None and w2_zeros is not None - assert has_no_zp or has_all_zp, ("zero points must be both not None or " - "must be both None") - M, K = hidden_states.shape E = w1.shape[0] N = w2.shape[1] * 16 @@ -234,122 +241,128 @@ def fused_marlin_moe( block_size_m = config["BLOCK_SIZE_M"] - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + if global_num_experts == -1: + global_num_experts = E + sorted_token_ids, expert_ids, num_tokens_post_padded = \ + moe_align_block_size(topk_ids, block_size_m, global_num_experts, + expert_map) - max_workspace_size = (max(2 * N, K) // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device=current_platform.device_type, - requires_grad=False) + if workspace is None: + max_workspace_size = (max(2 * N, K) // 64) * \ + (sorted_token_ids.size(0) // block_size_m) + device = hidden_states.device + sms = torch.cuda.get_device_properties(device).multi_processor_count + max_workspace_size = min(max_workspace_size, sms * 4) + workspace 
= torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) - if has_no_zp: - w1_zeros = torch.empty((0, 0), - dtype=hidden_states.dtype, - device=hidden_states.device, - requires_grad=False) - w2_zeros = torch.empty((0, 0), - dtype=hidden_states.dtype, - device=hidden_states.device, - requires_grad=False) - - if has_no_act_order: - g_idx1 = torch.empty((0, 0), - dtype=torch.int32, - device=hidden_states.device, - requires_grad=False) - g_idx2 = torch.empty((0, 0), - dtype=torch.int32, - device=hidden_states.device, - requires_grad=False) - sort_indices1 = torch.empty((0), - dtype=torch.int32, - device=hidden_states.device, - requires_grad=False) - sort_indices2 = torch.empty((0, 0), - dtype=torch.int32, - device=hidden_states.device, - requires_grad=False) - - scalar_type1 = get_scalar_type(num_bits, has_all_zp) - scalar_type2 = get_scalar_type(num_bits, has_all_zp) + scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None) + scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), device=hidden_states.device, dtype=hidden_states.dtype, ) + intermediate_cache13 = torch.empty( + (M * topk_ids.shape[1] * max(2 * N, K), ), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N] + intermediate_cache1 = intermediate_cache1.view(-1, 2 * N) + intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K] + intermediate_cache3 = intermediate_cache3.view(-1, K) - intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( + use_atomic_add = hidden_states.dtype == torch.half or \ + torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 + + intermediate_cache1 = ops.moe_wna16_marlin_gemm( hidden_states, + intermediate_cache1, w1, - sorted_token_ids, - topk_weights, - topk_ids, w1_scale, w1_zeros, g_idx1, sort_indices1, workspace, - scalar_type1.id, - M, - 2 * N, - K, - is_k_full, - E, - topk, - block_size_m, - True, - False, - ) + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + topk_weights, + moe_block_size=block_size_m, + top_k=topk, + mul_topk_weights=False, + is_ep=expert_map is not None, + b_q_type=scalar_type1, + size_m=M, + size_n=2 * N, + size_k=K, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=True, + is_zp_float=False) torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) - intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( + if expert_map is not None: + intermediate_cache3.zero_() + + intermediate_cache3 = ops.moe_wna16_marlin_gemm( intermediate_cache2, + intermediate_cache3, w2, - sorted_token_ids, - topk_weights, - topk_ids, w2_scale, w2_zeros, g_idx2, sort_indices2, workspace, - scalar_type2.id, - M, - K, - N, - is_k_full, - E, - topk, - block_size_m, - False, - True, - ) + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + topk_weights, + moe_block_size=block_size_m, + top_k=1, + mul_topk_weights=True, + is_ep=expert_map is not None, + b_q_type=scalar_type2, + size_m=M * topk, + size_n=K, + size_k=N, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=True, + is_zp_float=False).view(-1, topk, K) + output = hidden_states if inplace else torch.empty_like(hidden_states) return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) + dim=1, + out=output) -def fused_marlin_moe_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: 
torch.Tensor, - w2_scale: torch.Tensor, - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - g_idx1: Optional[torch.Tensor] = None, - g_idx2: Optional[torch.Tensor] = None, - sort_indices1: Optional[torch.Tensor] = None, - sort_indices2: Optional[torch.Tensor] = None, - w1_zeros: Optional[torch.Tensor] = None, - w2_zeros: Optional[torch.Tensor] = None, - num_bits: int = 8, - is_k_full: bool = True, -) -> torch.Tensor: +def fused_marlin_moe_fake(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + inplace: bool = False) -> torch.Tensor: return torch.empty_like(hidden_states) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 38d739d5..2a988b86 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -773,6 +773,18 @@ def get_default_config( config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1} else: config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1} + elif is_marlin: + for block_size_m in [8, 16, 32, 48, 64]: + if M * topk / E / block_size_m < 0.9: + break + return {"BLOCK_SIZE_M": block_size_m} + elif M <= E: + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } else: config = { "BLOCK_SIZE_M": 64, @@ -780,14 +792,6 @@ def get_default_config( "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): - config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - } return config diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 89a7548d..6e32e3e2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -472,6 +472,7 @@ class FusedMoE(torch.nn.Module): self.global_num_experts = num_experts assert intermediate_size % self.tp_size == 0 + self.hidden_size = hidden_size self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results self.renormalize = renormalize diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index cb1d5400..ef4a7765 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -17,14 +17,13 @@ from vllm.model_executor.layers.quantization.awq import (AWQConfig, is_layer_skipped_awq) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, 
awq_to_marlin_zero_points, check_marlin_supported,
-    check_marlin_supports_layer, marlin_make_empty_g_idx,
-    marlin_make_workspace, marlin_moe_permute_scales, marlin_permute_scales,
-    moe_awq_to_marlin_zero_points, verify_marlin_supported,
-    verify_marlin_supports_shape)
+    check_marlin_supports_layer, check_moe_marlin_supports_layer,
+    marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
+    marlin_permute_scales, moe_awq_to_marlin_zero_points,
+    verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            PackedvLLMParameter)
@@ -136,12 +135,15 @@ class AWQMarlinConfig(QuantizationConfig):
                 self.full_config).get_quant_method(layer, prefix)
             return AWQMarlinLinearMethod(self)
         elif isinstance(layer, FusedMoE):
-            if layer.local_num_experts > 32:
-                # For MoEs with many experts the moe_wna16 kernel is faster
+            from vllm.model_executor.layers.quantization.moe_wna16 import (
+                MoeWNA16Config)
+            if not check_moe_marlin_supports_layer(layer, self.group_size):
+                logger.warning_once(
+                    f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
+                    "Falling back to Moe WNA16 kernels.")
                 return MoeWNA16Config.from_config(
                     self.full_config).get_quant_method(layer, prefix)
-            else:
-                return AWQMoEMethod(self)
+            return AWQMoEMethod(self)
         return None

     @classmethod
@@ -391,6 +393,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
         layer.register_parameter("w2_qzeros", w2_qzeros)
         set_weight_attrs(w2_qzeros, extra_weight_attrs)

+        device = layer.w13_qweight.device
+        sms = torch.cuda.get_device_properties(device).multi_processor_count
+        layer.workspace = torch.zeros((sms * 4, ),
+                                      dtype=torch.int,
+                                      device=device,
+                                      requires_grad=False)
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         num_experts = layer.w13_qweight.shape[0]
         device = layer.w13_qweight.device
@@ -473,10 +482,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
         activation: str = "silu",
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."
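The fallback above is driven by check_moe_marlin_supports_layer, which this patch adds to marlin_utils.py further below. A standalone restatement with illustrative shapes follows (the function name and example sizes are for exposition only, not part of the patch):

def moe_marlin_supports(hidden_size: int,
                        intermediate_size_per_partition: int,
                        group_size: int) -> bool:
    # gate-up GEMM: (n, k) = (2 * intermediate, hidden)
    # down GEMM:    (n, k) = (hidden, intermediate)
    # The Marlin MoE kernel needs n % 128 == 0 and k % 64 == 0 for both,
    # plus a group size it has kernels for.
    return (hidden_size % 128 == 0
            and intermediate_size_per_partition % max(64, group_size) == 0
            and group_size in [-1, 32, 64, 128])

# 4096 hidden / 7168 intermediate per partition with group size 128 stays
# on the Marlin path ...
assert moe_marlin_supports(4096, 7168, 128)
# ... while an unsupported group size (or a misaligned shape) falls back
# to the MoeWNA16Config path above.
assert not moe_marlin_supports(4096, 7168, 256)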
-        if expert_map is not None:
-            raise NotImplementedError(
-                "Expert Parallelism is not supported for "
-                "fused Marlin MoE method.")
+
         if apply_router_weight_on_input:
             raise NotImplementedError(
                 "Apply router weight on input is not supported for"
                 "fused Marlin MoE method.")
@@ -503,7 +509,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
             router_logits,
             topk_weights,
             topk_ids,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
             w1_zeros=layer.w13_qzeros,
             w2_zeros=layer.w2_qzeros,
+            workspace=layer.workspace,
             num_bits=self.quant_config.weight_bits,
         )
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 0615bb4a..52cd0a5b 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -15,13 +15,13 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
-from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.gptq_utils import (
     get_linear_quant_method)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    check_marlin_supported, marlin_moe_permute_scales,
-    marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
+    check_marlin_supported, check_moe_marlin_supports_layer,
+    marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks,
+    verify_marlin_supported)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
                                            PackedColumnParameter,
@@ -153,12 +153,15 @@ class GPTQMarlinConfig(QuantizationConfig):
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, FusedMoE):
-            if layer.local_num_experts > 32:
-                # For MoEs with many experts the moe_wna16 kernel is faster
+            from vllm.model_executor.layers.quantization.moe_wna16 import (
+                MoeWNA16Config)
+            if not check_moe_marlin_supports_layer(layer, self.group_size):
+                logger.warning_once(
+                    f"Layer '{prefix}' is not supported by GPTQMoeMarlin. 
" + "Falling back to Moe WNA16 kernels.") return MoeWNA16Config.from_config( self.full_config).get_quant_method(layer, prefix) - else: - return GPTQMarlinMoEMethod(self) + return GPTQMarlinMoEMethod(self) return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod) @@ -408,7 +411,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): torch.empty(num_experts, scales_size13, 2 * intermediate_size_per_partition, - dtype=torch.half), + dtype=params_dtype), requires_grad=False, ) layer.register_parameter("w13_scales", w13_scales) @@ -418,7 +421,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): torch.empty(num_experts, scales_size2, hidden_size, - dtype=torch.half), + dtype=params_dtype), requires_grad=False, ) layer.register_parameter("w2_scales", w2_scales) @@ -493,6 +496,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): w2_g_idx_sort_indices) set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + device = layer.w13_qweight.device + sms = torch.cuda.get_device_properties(device).multi_processor_count + layer.workspace = torch.zeros((sms * 4, ), + dtype=torch.int, + device=device, + requires_grad=False) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Process act_order @@ -601,10 +611,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): "Apply router weight on input is not supported for" "fused Marlin MoE method.") - # The input must currently be float16 - orig_dtype = x.dtype - x = x.half() - topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -626,9 +632,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): router_logits, topk_weights, topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map, g_idx1=layer.w13_g_idx, g_idx2=layer.w2_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, num_bits=self.quant_config.quant_type.size_bits, - is_k_full=self.is_k_full).to(orig_dtype) + workspace=layer.workspace, + is_k_full=self.is_k_full) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 5b2e3ca2..1ccfae91 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -151,6 +151,19 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \ group_size=group_size)[0] +def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \ + -> bool: + hidden_size = layer.hidden_size + intermediate_size_per_partition = layer.intermediate_size_per_partition + + # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size) + # down: (n, k) = (hidden_size, intermediate_size_per_partition) + # moe marlin requires n % 128 == 0 and k % 64 == 0 + return hidden_size % 128 == 0 and \ + intermediate_size_per_partition % max(64, group_size) == 0 and \ + group_size in [-1, 32, 64, 128] + + def marlin_make_workspace(output_size_per_partition: int, device: torch.device) -> torch.Tensor: max_workspace_size = (output_size_per_partition //