[V1][Sampler] Faster top-k only implementation (#15478)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-03-26 10:56:47 -07:00 committed by GitHub
parent 733e7c9e95
commit 35fad35a48
3 changed files with 91 additions and 5 deletions

View File

@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
import torch
from torch import Generator
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
DEVICE = "cuda"
BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024
def test_topk_impl_equivalance():
    with torch.device(DEVICE):
        generator = Generator(device=DEVICE).manual_seed(33)
        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
        # Random top-k values between 1 and 9.
        k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
        # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
        k.masked_fill_(
            torch.randint(0,
                          2, (BATCH_SIZE, ),
                          generator=generator,
                          dtype=bool), VOCAB_SIZE)
        # Top-k only implementation
        result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
        # Top-p + top-k
        no_op_top_p = torch.tensor([1.0])
        result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
        assert torch.allclose(result1, result2)

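For reference, the reason `no_op_top_p = torch.tensor([1.0])` disables top-p in the combined path is that with p = 1 the cumulative-probability cutoff keeps every token, so only the top-k mask can change the logits. A minimal sketch of the same equivalence check on CPU with toy sizes (illustrative only, not part of this commit; the variable names are made up):

# Not part of the diff: toy-size CPU version of the equivalence check above.
import torch
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p

toy_logits = torch.randn(4, 16)
toy_k = torch.tensor([1, 3, 16, 5])  # k == vocab_size (16) disables top-k for row 2
topk_only = apply_top_k_top_p(logits=toy_logits.clone(), k=toy_k, p=None)
# p = 1.0 keeps every token in the top-p mask, so only top-k has an effect.
combined = apply_top_k_top_p(logits=toy_logits.clone(), k=toy_k, p=torch.tensor([1.0]))
assert torch.allclose(topk_only, combined)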
View File

@@ -19,6 +19,12 @@ except ImportError:
class TopKTopPSampler(nn.Module):
    """
    Module that performs optional top-k and top-p filtering followed by
    weighted random sampling of logits.

    Implementations may update the logits tensor in-place.
    """

    def __init__(self):
        super().__init__()
@@ -84,7 +90,11 @@ class TopKTopPSampler(nn.Module):
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """PyTorch-native implementation of top-k and top-p sampling."""
        """
        PyTorch-native implementation of top-k and top-p sampling.

        The logits tensor may be updated in-place.
        """
        logits = apply_top_k_top_p(logits, k, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)
@@ -136,10 +146,18 @@ def apply_top_k_top_p(
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.
    This function sorts the logits tensor, which can be slow for large batches.
    If a top-p is used, this function will sort the logits tensor,
    which can be slow for large batches.

    The logits tensor may be updated in-place.
    """
    if k is None and p is None:
        return logits
    if p is None:
        if k is None:
            return logits
        # Avoid sorting vocab for top-k only case.
        return apply_top_k_only(logits, k)
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
    if k is not None:
@@ -153,7 +171,7 @@
    if p is not None:
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = probs_sort.cumsum(dim=-1)
        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # at least one
        top_p_mask[:, -1] = False
@@ -164,6 +182,31 @@
    return logits


def apply_top_k_only(
    logits: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """
    Apply top-k mask to the logits.

    This implementation doesn't involve sorting the entire vocab.

    The logits tensor may be updated in-place.
    """
    no_top_k_mask = k == logits.shape[1]
    # Set non-top-k rows to 1 so that we can gather.
    k = k.masked_fill(no_top_k_mask, 1)
    max_top_k = k.max()
    # topk.values tensor has shape [batch_size, max_top_k].
    # Convert top k to 0-based index in range [0, max_top_k).
    k_index = k.sub_(1).unsqueeze(1)
    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index)
    # Handle non-topk rows.
    top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
    logits.masked_fill_(logits < top_k_mask, -float("inf"))
    return logits


def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],

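The core of the fast path is that `apply_top_k_only` computes one batched `topk` of size `max_top_k` for the whole batch and then has each row gather the value at its own `k-1` position as a cutoff threshold, instead of sorting the entire vocab. A small standalone sketch of the same trick on a toy batch (independent of vLLM; names are illustrative, not from this commit):

# Not part of the diff: standalone illustration of the gather-based cutoff.
import torch

def top_k_mask_sketch(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Per-row top-k masking without sorting the full vocab (illustrative only)."""
    vocab_size = logits.shape[1]
    no_top_k_mask = k == vocab_size                # rows where top-k is disabled
    k = k.masked_fill(no_top_k_mask, 1)            # keep the gather index valid there
    max_top_k = int(k.max())
    # One batched topk of max_top_k values, then each row gathers its own cutoff.
    k_index = (k - 1).unsqueeze(1)                 # 0-based index into the topk values
    cutoff = logits.topk(max_top_k, dim=1).values.gather(1, k_index)
    cutoff.masked_fill_(no_top_k_mask.unsqueeze(1), float("-inf"))  # disabled rows keep all
    return logits.masked_fill(logits < cutoff, float("-inf"))

toy_logits = torch.tensor([[0.1, 0.9, 0.5, 0.3],
                           [0.2, 0.8, 0.7, 0.4]])
toy_k = torch.tensor([1, 4])   # row 0: keep top-1; row 1: k == vocab_size, keep everything
print(top_k_mask_sketch(toy_logits, toy_k))
# row 0 -> only 0.9 survives; row 1 -> unchanged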
View File

@@ -87,6 +87,12 @@ class Sampler(nn.Module):
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.
        """
        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_random:
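Since the logits-processing helpers may now update the tensor in-place, a caller that still needs the unmodified logits afterwards (e.g. for logprob reporting) should hand the sampler a copy. An illustrative caller-side sketch, assuming the method shown above is `Sampler.sample` and that `sampler`, `logits`, and `sampling_metadata` already exist in the calling code:

# Not part of the diff: hypothetical caller keeping the raw logits intact.
raw_logits = logits                  # still needed later, e.g. for logprobs
sampled = sampler.sample(raw_logits.clone(), sampling_metadata)
# raw_logits is untouched; the clone may have been masked/modified in-place.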