[BugFix] Accuracy fix for llama4 int4 - improperly casted scales (#16801)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
parent 6a0f547561
commit 7eb4255628
@@ -13,7 +13,6 @@
 template <typename scalar_t, int bit, int GROUPS>
 __global__ void moe_wna16_gemm_kernel(
     const scalar_t* __restrict__ input, scalar_t* __restrict__ output,
-
     const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
     const uint32_t* __restrict__ qzeros,
 
@@ -54,8 +53,6 @@ __global__ void moe_wna16_gemm_kernel(
       if (token_index / top_k >= size_m) break;
 
       num_valid_tokens = m + 1;
-      if (blockIdx.z == 0 && offset_n < size_n)
-        output[token_index * size_n + offset_n] = Dtype::int2num(0);
 
       if (expert_id != -1) {
         int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
@@ -284,8 +281,7 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                              int64_t BLOCK_SIZE_K, int64_t bit) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-  auto options =
-      torch::TensorOptions().dtype(input.dtype()).device(input.device());
+  output.zero_();
 
   const int num_experts = b_qweight.size(0);
   const int size_m = input.size(0);
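Taken together with the kernel hunk above, the output buffer is now zeroed once on the host before the kernel launch instead of being zeroed row-by-row by the blockIdx.z == 0 blocks inside the kernel. A minimal PyTorch sketch of that pattern, with assumed shapes and names (illustrative only, not the actual launch path):

    import torch

    # Assumed accumulator shape for illustration: (num_tokens, size_n).
    num_tokens, size_n = 8, 16
    output = torch.empty(num_tokens, size_n, dtype=torch.float16)

    # Zero once up front (mirrors the added output.zero_() in moe_wna16_gemm);
    # each k-slice can then accumulate partial results without initializing rows itself.
    output.zero_()
    for _ in range(4):  # stand-in for the blockIdx.z slices over the K dimension
        partial = torch.randn(num_tokens, size_n).to(torch.float16)
        output += partial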
@@ -302,9 +298,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
   const uint32_t* b_qzeros_ptr;
   if (b_qzeros.has_value())
     b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
-  const float* topk_weights_ptr;
+  const float* topk_weights_ptr = nullptr;
   if (topk_weights.has_value())
-    topk_weights_ptr = (const float*)topk_weights.value().data_ptr();
+    topk_weights_ptr = (const float*)topk_weights.value().data_ptr<float>();
 
   int groups_per_block_row = BLOCK_SIZE_K / group_size;
   TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
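The topk_weights change above does two things: it gives the optional pointer a defined default (nullptr) so it is never read uninitialized when no top-k weights are passed, and it uses the typed data_ptr<float>() accessor. A small Python analogue of the defined-default idea, with hypothetical names (resolve_topk_weights is not vLLM code):

    from typing import Optional
    import torch

    def resolve_topk_weights(topk_weights: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        # Start from a defined "null" value, mirroring `= nullptr` in the C++ change,
        # so callers can safely branch on the result when no weights were supplied.
        weights = None
        if topk_weights is not None:
            weights = topk_weights.float()
        return weights

    print(resolve_topk_weights(None))
    print(resolve_topk_weights(torch.tensor([0.7, 0.3])))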
@@ -422,6 +422,7 @@ class FusedMoE(torch.nn.Module):
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
 
         # Note: here we guard against accessing the TP and DP groups when
         # uninitialized (this happens when testing)
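FusedMoE now records the resolved params_dtype on the module. A hedged sketch of the intent, using a hypothetical TinyMoE class (not vLLM code): later weight-loading or quantization code can read layer.params_dtype instead of re-deriving the dtype, for example when casting scale tensors.

    import torch
    import torch.nn as nn

    class TinyMoE(nn.Module):
        """Hypothetical stand-in for the FusedMoE constructor logic."""

        def __init__(self, params_dtype=None):
            super().__init__()
            if params_dtype is None:
                params_dtype = torch.get_default_dtype()
            # Keep the resolved dtype around for later consumers.
            self.params_dtype = params_dtype

    layer = TinyMoE()
    # Hypothetical downstream use: make a scale tensor match the layer dtype.
    scales = torch.ones(4, dtype=torch.float32).to(layer.params_dtype)
    print(scales.dtype)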
@@ -51,8 +51,8 @@ class Llama4MoE(nn.Module):
         renormalize: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         router_scores, router_indices = fast_topk(gating_output, topk, dim=-1)
-        router_scores = torch.sigmoid(router_scores.float()).to(
-            hidden_states.dtype)
+        # pseudo-standard is that the router scores are floats
+        router_scores = torch.sigmoid(router_scores.float())
         return (router_scores, router_indices.to(torch.int32))
 
     def __init__(self,
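The Llama4MoE change drops the early cast of the router scores to hidden_states.dtype, so the sigmoid routing weights stay in float32 until they are actually applied. A small, self-contained illustration of the precision an early bfloat16 cast throws away (values are made up; this is not the model's routing code):

    import torch

    logits = torch.tensor([-4.25, -3.10, 0.35, 2.60])
    scores_fp32 = torch.sigmoid(logits.float())   # kept in float32 after the fix
    scores_bf16 = scores_fp32.to(torch.bfloat16)  # old behavior: early downcast

    print(scores_fp32)
    print(scores_bf16.float() - scores_fp32)      # rounding error introduced by the cast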