diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 5dfcae08..d4bc2336 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -200,7 +200,7 @@ def apply_top_k_only( # topk.values tensor has shape [batch_size, max_top_k]. # Convert top k to 0-based index in range [0, max_top_k). k_index = k.sub_(1).unsqueeze(1) - top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index) + top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long()) # Handle non-topk rows. top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf")) logits.masked_fill_(logits < top_k_mask, -float("inf"))