From 7eb42556281d30436a3a988f2c9184ec63c59338 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Fri, 18 Apr 2025 01:13:29 -0400
Subject: [PATCH] [BugFix] Accuracy fix for llama4 int4 - improperly casted
 scales (#16801)

Signed-off-by: Lucas Wilkinson
---
 csrc/moe/moe_wna16.cu                         | 10 +++-------
 vllm/model_executor/layers/fused_moe/layer.py |  1 +
 vllm/model_executor/models/llama4.py          |  4 ++--
 3 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/csrc/moe/moe_wna16.cu b/csrc/moe/moe_wna16.cu
index 51ae76c1..7b6a111c 100644
--- a/csrc/moe/moe_wna16.cu
+++ b/csrc/moe/moe_wna16.cu
@@ -13,7 +13,6 @@ template
 __global__ void moe_wna16_gemm_kernel(
     const scalar_t* __restrict__ input, scalar_t* __restrict__ output,
-
     const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
     const uint32_t* __restrict__ qzeros,
@@ -54,8 +53,6 @@ __global__ void moe_wna16_gemm_kernel(
 
     if (token_index / top_k >= size_m) break;
     num_valid_tokens = m + 1;
-    if (blockIdx.z == 0 && offset_n < size_n)
-      output[token_index * size_n + offset_n] = Dtype::int2num(0);
 
     if (expert_id != -1) {
       int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
@@ -284,8 +281,7 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                              int64_t BLOCK_SIZE_K, int64_t bit) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-  auto options =
-      torch::TensorOptions().dtype(input.dtype()).device(input.device());
+  output.zero_();
 
   const int num_experts = b_qweight.size(0);
   const int size_m = input.size(0);
@@ -302,9 +298,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
   const uint32_t* b_qzeros_ptr;
   if (b_qzeros.has_value())
     b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr();
-  const float* topk_weights_ptr;
+  const float* topk_weights_ptr = nullptr;
   if (topk_weights.has_value())
-    topk_weights_ptr = (const float*)topk_weights.value().data_ptr();
+    topk_weights_ptr = (const float*)topk_weights.value().data_ptr();
 
   int groups_per_block_row = BLOCK_SIZE_K / group_size;
   TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 6e32e3e2..43fb3112 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -422,6 +422,7 @@ class FusedMoE(torch.nn.Module):
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
 
         # Note: here we guard against accessing the TP and DP groups when
         # uninitialized (this happens when testing)
diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index 51efbfe2..e5d1a671 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -51,8 +51,8 @@ class Llama4MoE(nn.Module):
         renormalize: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         router_scores, router_indices = fast_topk(gating_output, topk, dim=-1)
-        router_scores = torch.sigmoid(router_scores.float()).to(
-            hidden_states.dtype)
+        # pseudo-standard is that the router scores are floats
+        router_scores = torch.sigmoid(router_scores.float())
        return (router_scores, router_indices.to(torch.int32))
 
     def __init__(self,
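
Note on the llama4.py hunk: the old code cast the sigmoided router scores back
to the hidden-state dtype (typically bfloat16 for Llama-4), which rounds the
per-expert routing weights before they multiply the int4-dequantized expert
outputs; the fix keeps them in float32. The sketch below only illustrates the
size of that rounding; it uses torch.topk as a stand-in for vLLM's fast_topk
and assumes a bfloat16 hidden dtype, and is not code taken from the patch:

    import torch

    torch.manual_seed(0)
    gating_output = torch.randn(4, 128)             # router logits (fp32 for the demo)
    topk_vals, topk_idx = torch.topk(gating_output, k=2, dim=-1)

    scores_fp32 = torch.sigmoid(topk_vals.float())  # new behaviour: stay in float32
    scores_bf16 = scores_fp32.to(torch.bfloat16)    # old behaviour: cast to hidden dtype

    # bfloat16 carries only ~3 significant decimal digits, so the round trip
    # perturbs routing weights near 1.0 by up to roughly 2e-3.
    print((scores_fp32 - scores_bf16.float()).abs().max())

Since the routing weight scales every selected expert's output, keeping it in
float32 avoids stacking this rounding on top of the error already introduced by
int4 weight quantization.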