[V1] Fix: make sure k_index is int64 for apply_top_k_only (#15907)

Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Brayden Zhong 2025-04-01 22:06:44 -04:00 committed by GitHub
parent 24b7fb455a
commit 6efb195a6e


@@ -200,7 +200,7 @@ def apply_top_k_only(
     # topk.values tensor has shape [batch_size, max_top_k].
     # Convert top k to 0-based index in range [0, max_top_k).
     k_index = k.sub_(1).unsqueeze(1)
-    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index)
+    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
     # Handle non-topk rows.
     top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
     logits.masked_fill_(logits < top_k_mask, -float("inf"))
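
For context, a minimal standalone sketch of why the cast is needed (not taken from the vLLM source; the tensor shapes and values below are illustrative): torch.gather requires its index tensor to be int64, so if k arrives as int32 the gather call fails unless k_index is cast with .long().

import torch

# Illustrative inputs: 4 rows of logits, per-row top-k values stored as int32.
logits = torch.randn(4, 8)
k = torch.tensor([3, 1, 5, 2], dtype=torch.int32)

max_top_k = int(k.max())
# Convert top k to a 0-based index in range [0, max_top_k), shape [batch_size, 1].
k_index = k.sub_(1).unsqueeze(1)

topk_values = logits.topk(max_top_k, dim=1).values
# gather() expects an int64 index tensor; passing the int32 k_index directly
# raises a RuntimeError, hence the .long() cast added in this commit.
top_k_mask = topk_values.gather(1, k_index.long())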