[BugFix] Accuracy fix for llama4 int4 - improperly casted scales (#16801)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Authored by Lucas Wilkinson on 2025-04-18 01:13:29 -04:00; committed by GitHub
parent 6a0f547561
commit 7eb4255628
3 changed files with 6 additions and 9 deletions

File 1 of 3:

@@ -13,7 +13,6 @@
 template <typename scalar_t, int bit, int GROUPS>
 __global__ void moe_wna16_gemm_kernel(
     const scalar_t* __restrict__ input, scalar_t* __restrict__ output,
     const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
     const uint32_t* __restrict__ qzeros,
@@ -54,8 +53,6 @@ __global__ void moe_wna16_gemm_kernel(
       if (token_index / top_k >= size_m) break;
       num_valid_tokens = m + 1;
-      if (blockIdx.z == 0 && offset_n < size_n)
-        output[token_index * size_n + offset_n] = Dtype::int2num(0);
       if (expert_id != -1) {
         int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
@@ -284,8 +281,7 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                              int64_t BLOCK_SIZE_K, int64_t bit) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-  auto options =
-      torch::TensorOptions().dtype(input.dtype()).device(input.device());
+  output.zero_();

   const int num_experts = b_qweight.size(0);
   const int size_m = input.size(0);
@@ -302,9 +298,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
   const uint32_t* b_qzeros_ptr;
   if (b_qzeros.has_value())
     b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
-  const float* topk_weights_ptr;
+  const float* topk_weights_ptr = nullptr;
   if (topk_weights.has_value())
-    topk_weights_ptr = (const float*)topk_weights.value().data_ptr();
+    topk_weights_ptr = (const float*)topk_weights.value().data_ptr<float>();

   int groups_per_block_row = BLOCK_SIZE_K / group_size;
   TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
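
In this file the per-thread zeroing of output (the removed blockIdx.z == 0 branch) moves to a single output.zero_() in the host wrapper, and topk_weights_ptr is now initialized to nullptr and read through a typed data_ptr<float>(), which raises on a dtype mismatch instead of reinterpreting the underlying bytes. Below is a minimal caller-side sketch of that contract in PyTorch terms; the tensor names and shapes are illustrative, not the actual op signature.

    import torch

    # Illustrative sizes only.
    num_tokens, size_n, top_k = 8, 128, 2

    # The output buffer is zeroed once on the host before the kernel launch,
    # replacing the old per-thread zeroing inside the kernel.
    output = torch.empty(num_tokens, size_n, dtype=torch.bfloat16)
    output.zero_()

    # The top-k router weights are expected to actually be float32: the host
    # code now reads them via data_ptr<float>(), which errors on any other dtype.
    topk_weights = torch.rand(num_tokens, top_k, dtype=torch.float32)
    assert topk_weights.dtype == torch.float32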

File 2 of 3:

@@ -422,6 +422,7 @@ class FusedMoE(torch.nn.Module):
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype

         # Note: here we guard against accessing the TP and DP groups when
         # uninitialized (this happens when testing)
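
The only change here is that the resolved params_dtype is stored on the layer so later code can consult it. A minimal sketch of that pattern with a hypothetical consumer follows; make_scales is illustrative, not a vLLM API.

    import torch
    import torch.nn as nn

    class TinyMoELayer(nn.Module):
        """Toy stand-in for FusedMoE: resolves and stashes params_dtype."""

        def __init__(self, params_dtype=None):
            super().__init__()
            if params_dtype is None:
                params_dtype = torch.get_default_dtype()
            self.params_dtype = params_dtype  # now available to later consumers

    def make_scales(layer: TinyMoELayer, num_groups: int) -> torch.Tensor:
        # Hypothetical consumer: builds per-group scales in the layer's parameter dtype.
        return torch.ones(num_groups, dtype=layer.params_dtype)

    layer = TinyMoELayer()
    scales = make_scales(layer, num_groups=4)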

File 3 of 3:

@@ -51,8 +51,8 @@ class Llama4MoE(nn.Module):
         renormalize: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         router_scores, router_indices = fast_topk(gating_output, topk, dim=-1)
-        router_scores = torch.sigmoid(router_scores.float()).to(
-            hidden_states.dtype)
+        # psuedo-standard is that the router scores are floats
+        router_scores = torch.sigmoid(router_scores.float())
         return (router_scores, router_indices.to(torch.int32))

     def __init__(self,
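
This is the improper cast the commit title points at: the sigmoid of the router logits was being cast back to the activation dtype (e.g. bfloat16) before being returned, and keeping the scores in float32 also lines up with the moe_wna16 host code above, which now reads topk_weights through data_ptr<float>(). A minimal before/after sketch, using torch.topk as a stand-in for vLLM's fast_topk helper:

    import torch
    from typing import Tuple

    def routing_before(gating_output: torch.Tensor, hidden_states: torch.Tensor,
                       topk: int) -> Tuple[torch.Tensor, torch.Tensor]:
        scores, indices = torch.topk(gating_output, topk, dim=-1)
        # Old behaviour: cast the fp32 sigmoid back to the activation dtype.
        scores = torch.sigmoid(scores.float()).to(hidden_states.dtype)
        return scores, indices.to(torch.int32)

    def routing_after(gating_output: torch.Tensor,
                      topk: int) -> Tuple[torch.Tensor, torch.Tensor]:
        scores, indices = torch.topk(gating_output, topk, dim=-1)
        # New behaviour: keep the router scores in float32.
        scores = torch.sigmoid(scores.float())
        return scores, indices.to(torch.int32)

    gating = torch.randn(4, 16, dtype=torch.bfloat16)
    hidden = torch.randn(4, 32, dtype=torch.bfloat16)
    old_scores, _ = routing_before(gating, hidden, topk=1)
    new_scores, _ = routing_after(gating, topk=1)
    assert old_scores.dtype == torch.bfloat16 and new_scores.dtype == torch.float32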