[BugFix] Accuracy fix for llama4 int4 - improperly casted scales (#16801)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
parent
6a0f547561
commit
7eb4255628
@ -13,7 +13,6 @@
|
||||
template <typename scalar_t, int bit, int GROUPS>
|
||||
__global__ void moe_wna16_gemm_kernel(
|
||||
const scalar_t* __restrict__ input, scalar_t* __restrict__ output,
|
||||
|
||||
const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
|
||||
const uint32_t* __restrict__ qzeros,
|
||||
|
||||
@ -54,8 +53,6 @@ __global__ void moe_wna16_gemm_kernel(
|
||||
if (token_index / top_k >= size_m) break;
|
||||
|
||||
num_valid_tokens = m + 1;
|
||||
if (blockIdx.z == 0 && offset_n < size_n)
|
||||
output[token_index * size_n + offset_n] = Dtype::int2num(0);
|
||||
|
||||
if (expert_id != -1) {
|
||||
int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
|
||||
@ -284,8 +281,7 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
||||
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
||||
int64_t BLOCK_SIZE_K, int64_t bit) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(input.dtype()).device(input.device());
|
||||
output.zero_();
|
||||
|
||||
const int num_experts = b_qweight.size(0);
|
||||
const int size_m = input.size(0);
|
||||
@ -302,9 +298,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
||||
const uint32_t* b_qzeros_ptr;
|
||||
if (b_qzeros.has_value())
|
||||
b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
|
||||
const float* topk_weights_ptr;
|
||||
const float* topk_weights_ptr = nullptr;
|
||||
if (topk_weights.has_value())
|
||||
topk_weights_ptr = (const float*)topk_weights.value().data_ptr();
|
||||
topk_weights_ptr = (const float*)topk_weights.value().data_ptr<float>();
|
||||
|
||||
int groups_per_block_row = BLOCK_SIZE_K / group_size;
|
||||
TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
|
||||
|
@ -422,6 +422,7 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
# Note: here we guard against accessing the TP and DP groups when
|
||||
# uninitialized (this happens when testing)
|
||||
|
@ -51,8 +51,8 @@ class Llama4MoE(nn.Module):
|
||||
renormalize: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
router_scores, router_indices = fast_topk(gating_output, topk, dim=-1)
|
||||
router_scores = torch.sigmoid(router_scores.float()).to(
|
||||
hidden_states.dtype)
|
||||
# psuedo-standard is that the router scores are floats
|
||||
router_scores = torch.sigmoid(router_scores.float())
|
||||
return (router_scores, router_indices.to(torch.int32))
|
||||
|
||||
def __init__(self,
|
||||
|
Loading…
x
Reference in New Issue
Block a user