[Bugfix] Fix Marlin MoE act order when is_k_full == False (#8741)

Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
ElizaWszola 2024-09-29 03:19:40 +02:00 committed by GitHub
parent 5bf8789b2a
commit d081da0064
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 37 additions and 18 deletions

3
csrc/core/exception.hpp Normal file
View File

@ -0,0 +1,3 @@
#pragma once
#define VLLM_IMPLIES(p, q) (!(p) || (q))

View File

@ -25,6 +25,7 @@
#include <iostream> #include <iostream>
#include "core/exception.hpp"
#include "core/scalar_type.hpp" #include "core/scalar_type.hpp"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h" #include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h" #include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
@ -189,7 +190,7 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int load_groups = int load_groups =
tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2; return load_groups * tb_n * 4;
} else { } else {
int tb_scales = tb_groups * tb_n * 2; int tb_scales = tb_groups * tb_n * 2;
@ -433,11 +434,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
int4* C_ptr = (int4*)C; int4* C_ptr = (int4*)C;
const float* topk_weights_ptr = (const float*)topk_weights; const float* topk_weights_ptr = (const float*)topk_weights;
const int* sorted_ids_ptr = (const int*)sorted_ids; const int* sorted_ids_ptr = (const int*)sorted_ids;
const int4* s_ptr = const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
(const int4*)s +
(((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) *
prob_n / 8) *
expert_idx;
const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
const int* perm_ptr = (const int*)perm + prob_k * expert_idx; const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
int* locks = (int*)workspace; int* locks = (int*)workspace;
@ -521,6 +518,9 @@ torch::Tensor marlin_gemm_moe(
" is not size_n = ", size_n); " is not size_n = ", size_n);
num_groups = b_scales.size(1); num_groups = b_scales.size(1);
TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order),
"if is_k_full is false, has_act_order must be true");
if (has_act_order) { if (has_act_order) {
if (is_k_full) { if (is_k_full) {
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");

View File

@ -145,6 +145,7 @@ def compute_max_diff(output, output_ref):
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8]) @pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
def test_fused_marlin_moe( def test_fused_marlin_moe(
m: int, m: int,
n: int, n: int,
@ -154,6 +155,7 @@ def test_fused_marlin_moe(
group_size: int, group_size: int,
act_order: bool, act_order: bool,
num_bits: int, num_bits: int,
is_k_full: bool,
): ):
seed_everything(7) seed_everything(7)
@ -166,6 +168,9 @@ def test_fused_marlin_moe(
return return
if group_size in (k, n): if group_size in (k, n):
return return
else:
if not is_k_full:
return
quant_type = (scalar_types.uint4b8 quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128) if num_bits == 4 else scalar_types.uint8b128)
@ -246,6 +251,7 @@ def test_fused_marlin_moe(
w1_scale=scales1, w1_scale=scales1,
w2_scale=scales2, w2_scale=scales2,
num_bits=num_bits, num_bits=num_bits,
is_k_full=is_k_full,
) )
assert compute_max_diff(marlin_output, triton_output) < 4e-2 assert compute_max_diff(marlin_output, triton_output) < 4e-2
@ -290,6 +296,7 @@ def test_fused_marlin_moe(
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8]) @pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
def test_single_marlin_moe_multiply( def test_single_marlin_moe_multiply(
m: int, m: int,
n: int, n: int,
@ -299,6 +306,7 @@ def test_single_marlin_moe_multiply(
group_size: int, group_size: int,
act_order: bool, act_order: bool,
num_bits: int, num_bits: int,
is_k_full: bool,
): ):
if topk > e: if topk > e:
return return
@ -309,6 +317,9 @@ def test_single_marlin_moe_multiply(
return return
if group_size == k: if group_size == k:
return return
else:
if not is_k_full:
return
quant_type = (scalar_types.uint4b8 quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128) if num_bits == 4 else scalar_types.uint8b128)
@ -339,7 +350,8 @@ def test_single_marlin_moe_multiply(
sort_indices = stack_and_dev(sort_indices_l) sort_indices = stack_and_dev(sort_indices_l)
score = torch.randn((m, e), device="cuda", dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(a, marlin_output = single_marlin_moe(
a,
qweight, qweight,
scales, scales,
score, score,
@ -347,7 +359,9 @@ def test_single_marlin_moe_multiply(
sort_indices, sort_indices,
topk, topk,
renormalize=False, renormalize=False,
num_bits=num_bits) num_bits=num_bits,
is_k_full=is_k_full,
)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
assert compute_max_diff(marlin_output, torch_output) < 1e-2 assert compute_max_diff(marlin_output, torch_output) < 1e-2

View File

@ -21,6 +21,7 @@ def single_marlin_moe(
renormalize: bool, renormalize: bool,
override_config: Optional[Dict[str, Any]] = None, override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8, num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes the multiplication of hidden_states with expert This function computes the multiplication of hidden_states with expert
@ -86,7 +87,7 @@ def single_marlin_moe(
intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk, g_idx, perm, workspace, scalar_type, M, N, K, is_k_full, E, topk,
block_size_m, True, False) block_size_m, True, False)
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
@ -107,6 +108,7 @@ def fused_marlin_moe(
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None,
num_bits: int = 8, num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
@ -199,7 +201,7 @@ def fused_marlin_moe(
M, M,
2 * N, 2 * N,
K, K,
True, is_k_full,
E, E,
topk, topk,
block_size_m, block_size_m,
@ -223,7 +225,7 @@ def fused_marlin_moe(
M, M,
K, K,
N, N,
True, is_k_full,
E, E,
topk, topk,
block_size_m, block_size_m,