[Bugfix] Fix function names in test_block_fp8.py (#16033)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm 2025-04-03 19:01:34 -04:00 committed by GitHub
parent f15e70d906
commit dcc56d62da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -360,7 +360,7 @@ def fp8_perm(m, idx):
return m[idx, ...]
def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
M, K = a.shape
sorted_token_ids, m_indices, num_pad = moe_align_block_size(
@ -379,7 +379,7 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
return a, a_s, m_indices, inv_perm
def test_moe_unpermute(out, inv_perm, topk, K, topk_weight):
def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
M = topk_weight.shape[0]
out = out[inv_perm, ...]
tmp_out = out.view(-1, topk, K)
@ -401,7 +401,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
a_q, a_s = per_token_group_quant_fp8(a, block_m)
a_q, a_s, m_indices, inv_perm = test_moe_permute(a_q, a_s, topk_ids,
a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids,
num_groups, topk, block_m)
inter_out = torch.zeros((a_q.shape[0], N * 2),
@ -419,7 +419,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(act_out_q, act_out_s), (w2, w2_s), out, m_indices)
final_out = test_moe_unpermute(out, inv_perm, topk, K, topk_weight)
final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)
return final_out