[Bugfix] Fix Marlin MoE act order when is_k_full == False (#8741)
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
commit d081da0064 (parent 5bf8789b2a)

csrc/core/exception.hpp (new file, +3)
@@ -0,0 +1,3 @@
+#pragma once
+
+#define VLLM_IMPLIES(p, q) (!(p) || (q))
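
Note: the new VLLM_IMPLIES(p, q) macro encodes logical implication
(p -> q, i.e. !p || q). It is used in marlin_gemm_moe below to require
that is_k_full == false only occurs together with act order. A minimal
compile-time sketch of its semantics (these static_asserts are
illustrative, not part of the commit):

    #include "core/exception.hpp"

    // p -> q is false only when p holds and q does not.
    static_assert(VLLM_IMPLIES(false, false), "vacuously true");
    static_assert(VLLM_IMPLIES(false, true), "vacuously true");
    static_assert(VLLM_IMPLIES(true, true), "holds when both are true");
    static_assert(!VLLM_IMPLIES(true, false), "only failing case: p without q");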

csrc/moe/marlin_moe_ops.cu
@@ -25,6 +25,7 @@
 
 #include <iostream>
 
+#include "core/exception.hpp"
 #include "core/scalar_type.hpp"
 #include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
 #include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
@@ -189,7 +190,7 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
     int load_groups =
         tb_groups * STAGES * 2;  // Chunk size is 2x pipeline over dim K
     load_groups = max(load_groups, 32);  // We load at least 32 scale groups
-    return load_groups * tb_n * 2;
+    return load_groups * tb_n * 4;
 
   } else {
     int tb_scales = tb_groups * tb_n * 2;
@@ -433,11 +434,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
   int4* C_ptr = (int4*)C;
   const float* topk_weights_ptr = (const float*)topk_weights;
   const int* sorted_ids_ptr = (const int*)sorted_ids;
-  const int4* s_ptr =
-      (const int4*)s +
-      (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) *
-       prob_n / 8) *
-          expert_idx;
+  const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
   const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
   const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
   int* locks = (int*)workspace;
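
Note on the s_ptr change above: the old offset re-derived the per-expert
group count from prob_k / group_size, which is wrong when K is not full
(with act order the scale tensor keeps groups for the full K even though
this invocation sees a reduced prob_k). The new form uses num_groups,
read from the scale tensor's actual shape (b_scales.size(1) below). A
minimal sketch with hypothetical shapes (not from the commit) showing
the two offsets agree in the full-K case:

    #include <cassert>

    int main() {
      // Hypothetical shapes, for illustration only.
      int prob_k = 4096, prob_n = 4096, group_size = 128, expert_idx = 3;
      int num_groups = prob_k / group_size;  // b_scales.size(1) when K is full

      // Old offset: group count re-derived from prob_k and group_size.
      long old_off =
          (long)(((group_size == -1 || group_size == 0) ? 1
                                                        : prob_k / group_size) *
                 prob_n / 8) *
          expert_idx;
      // New offset: group count taken from the scale tensor itself.
      long new_off = (long)num_groups * prob_n / 8 * expert_idx;

      assert(old_off == new_off);  // equal for full K; only the new form is
                                   // also correct when is_k_full == false
      return 0;
    }
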
@@ -521,6 +518,9 @@ torch::Tensor marlin_gemm_moe(
               " is not size_n = ", size_n);
   num_groups = b_scales.size(1);
 
+  TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order),
+              "if is_k_full is false, has_act_order must be true");
+
   if (has_act_order) {
     if (is_k_full) {
       TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");

tests/kernels/test_moe.py
@@ -145,6 +145,7 @@ def compute_max_diff(output, output_ref):
 @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
 @pytest.mark.parametrize("act_order", [True, False])
 @pytest.mark.parametrize("num_bits", [4, 8])
+@pytest.mark.parametrize("is_k_full", [True, False])
 def test_fused_marlin_moe(
     m: int,
     n: int,
@@ -154,6 +155,7 @@ def test_fused_marlin_moe(
     group_size: int,
     act_order: bool,
     num_bits: int,
+    is_k_full: bool,
 ):
     seed_everything(7)
 
@@ -166,6 +168,9 @@ def test_fused_marlin_moe(
             return
         if group_size in (k, n):
             return
+    else:
+        if not is_k_full:
+            return
 
     quant_type = (scalar_types.uint4b8
                   if num_bits == 4 else scalar_types.uint8b128)
@@ -246,6 +251,7 @@ def test_fused_marlin_moe(
         w1_scale=scales1,
         w2_scale=scales2,
         num_bits=num_bits,
+        is_k_full=is_k_full,
     )
 
     assert compute_max_diff(marlin_output, triton_output) < 4e-2
@@ -290,6 +296,7 @@ def test_fused_marlin_moe(
 @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
 @pytest.mark.parametrize("act_order", [True, False])
 @pytest.mark.parametrize("num_bits", [4, 8])
+@pytest.mark.parametrize("is_k_full", [True, False])
 def test_single_marlin_moe_multiply(
     m: int,
     n: int,
@@ -299,6 +306,7 @@ def test_single_marlin_moe_multiply(
     group_size: int,
     act_order: bool,
     num_bits: int,
+    is_k_full: bool,
 ):
     if topk > e:
         return
@@ -309,6 +317,9 @@ def test_single_marlin_moe_multiply(
             return
         if group_size == k:
             return
+    else:
+        if not is_k_full:
+            return
 
     quant_type = (scalar_types.uint4b8
                   if num_bits == 4 else scalar_types.uint8b128)
@@ -339,7 +350,8 @@ def test_single_marlin_moe_multiply(
     sort_indices = stack_and_dev(sort_indices_l)
 
     score = torch.randn((m, e), device="cuda", dtype=dtype)
-    marlin_output = single_marlin_moe(a,
+    marlin_output = single_marlin_moe(
+        a,
         qweight,
         scales,
         score,
@@ -347,7 +359,9 @@ def test_single_marlin_moe_multiply(
         sort_indices,
         topk,
         renormalize=False,
-        num_bits=num_bits)
+        num_bits=num_bits,
+        is_k_full=is_k_full,
+    )
     torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
 
     assert compute_max_diff(marlin_output, torch_output) < 1e-2

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -21,6 +21,7 @@ def single_marlin_moe(
     renormalize: bool,
     override_config: Optional[Dict[str, Any]] = None,
     num_bits: int = 8,
+    is_k_full: bool = True,
 ) -> torch.Tensor:
     """
     This function computes the multiplication of hidden_states with expert
@@ -86,7 +87,7 @@ def single_marlin_moe(
 
     intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
         hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
-        g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk,
+        g_idx, perm, workspace, scalar_type, M, N, K, is_k_full, E, topk,
         block_size_m, True, False)
 
     return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
@@ -107,6 +108,7 @@ def fused_marlin_moe(
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     num_bits: int = 8,
+    is_k_full: bool = True,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -199,7 +201,7 @@ def fused_marlin_moe(
         M,
         2 * N,
         K,
-        True,
+        is_k_full,
         E,
         topk,
         block_size_m,
@@ -223,7 +225,7 @@ def fused_marlin_moe(
         M,
         K,
         N,
-        True,
+        is_k_full,
         E,
         topk,
         block_size_m,