[V1] Fix: make sure k_index
is int64 for apply_top_k_only
(#15907)
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
This commit is contained in:
parent
24b7fb455a
commit
6efb195a6e
@ -200,7 +200,7 @@ def apply_top_k_only(
|
||||
# topk.values tensor has shape [batch_size, max_top_k].
|
||||
# Convert top k to 0-based index in range [0, max_top_k).
|
||||
k_index = k.sub_(1).unsqueeze(1)
|
||||
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index)
|
||||
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
|
||||
# Handle non-topk rows.
|
||||
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
|
||||
logits.masked_fill_(logits < top_k_mask, -float("inf"))
|
||||
|
Loading…
x
Reference in New Issue
Block a user