[Bugfix] Fix function names in test_block_fp8.py (#16033)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
f15e70d906
commit
dcc56d62da
@ -360,7 +360,7 @@ def fp8_perm(m, idx):
|
||||
return m[idx, ...]
|
||||
|
||||
|
||||
def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
|
||||
def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
|
||||
M, K = a.shape
|
||||
|
||||
sorted_token_ids, m_indices, num_pad = moe_align_block_size(
|
||||
@ -379,7 +379,7 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
|
||||
return a, a_s, m_indices, inv_perm
|
||||
|
||||
|
||||
def test_moe_unpermute(out, inv_perm, topk, K, topk_weight):
|
||||
def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
|
||||
M = topk_weight.shape[0]
|
||||
out = out[inv_perm, ...]
|
||||
tmp_out = out.view(-1, topk, K)
|
||||
@ -401,7 +401,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
|
||||
|
||||
a_q, a_s = per_token_group_quant_fp8(a, block_m)
|
||||
|
||||
a_q, a_s, m_indices, inv_perm = test_moe_permute(a_q, a_s, topk_ids,
|
||||
a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids,
|
||||
num_groups, topk, block_m)
|
||||
|
||||
inter_out = torch.zeros((a_q.shape[0], N * 2),
|
||||
@ -419,7 +419,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(act_out_q, act_out_s), (w2, w2_s), out, m_indices)
|
||||
|
||||
final_out = test_moe_unpermute(out, inv_perm, topk, K, topk_weight)
|
||||
final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)
|
||||
|
||||
return final_out
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user