[V1][TPU] Enable Top K (#15489)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
Co-authored-by: Hyesoo Yang <hyeygit@gmail.com>
This commit is contained in:
Nicolò Lucchesi 2025-04-17 20:18:11 +02:00 committed by GitHub
parent 5989f4684d
commit eb5819b2d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 69 additions and 26 deletions

View File

@ -1,4 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import random
import pytest
from vllm import LLM, envs
@ -39,3 +41,19 @@ def test_sampler_different(model_name: str):
# Unsupported `seed` param.
sampling_params = SamplingParams(temperature=0.3, seed=42)
output2 = llm.generate(prompts, sampling_params)
# Batch-case with TopK
for B in [4, 16]:
p = prompts * B
sampling_params = [
SamplingParams(
temperature=0.1,
min_p=0.8,
max_tokens=64,
# Vary number of ks
top_k=random.randint(4, 12)) for _ in range(B)
]
# Make sure first two reqs have the same K
sampling_params[0] = sampling_params[1]
output = llm.generate(p, sampling_params)
assert output[0].outputs[0].text == output[1].outputs[0].text

View File

@ -5,7 +5,8 @@ import pytest
import torch
from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu
from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
apply_top_k_top_p_tpu)
if not current_platform.is_tpu():
pytest.skip("This test needs a TPU.", allow_module_level=True)
@ -16,6 +17,25 @@ VOCAB_SIZE = 128 * 1024
TOLERANCE = 1e-6
def test_topk_equivalence_to_native_impl():
with torch.device(xm.xla_device()):
xm.set_rng_state(seed=33)
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
# Random top-k values between 1 and 10.
k = torch.randint(1, 10, (BATCH_SIZE, ))
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool),
VOCAB_SIZE)
result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)
result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
assert torch.allclose(result_native, result_tpu)
def test_topp_result_sums_past_p():
with torch.device(xm.xla_device()):
xm.set_rng_state(seed=33)

View File

@ -103,7 +103,6 @@ if TYPE_CHECKING:
VLLM_DP_MASTER_PORT: int = 0
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_USE_DEEP_GEMM: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0
@ -685,11 +684,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V0_USE_OUTLINES_CACHE":
lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
# If set, disables TPU-specific optimization for top-k & top-p sampling
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
# Gap between padding buckets for the forward pass. So we have
# 8, we will run forward pass with [16, 24, 32, ...].
"VLLM_TPU_BUCKET_PADDING_GAP":

View File

@ -72,14 +72,7 @@ class TopKTopPSampler(nn.Module):
"best performance, please install FlashInfer.")
self.forward = self.forward_native
elif current_platform.is_tpu():
if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION:
logger.warning(
"TPU-specific optimization for top-k & top-p sampling are "
"disabled, falling back to PyTorch-native implementation "
"which could be very slow.")
self.forward = self.forward_native
else:
self.forward = self.forward_tpu
self.forward = self.forward_tpu
else:
self.forward = self.forward_native
@ -146,12 +139,22 @@ def apply_top_k_top_p_tpu(
chance of being chosen during final sampling, so we can consider the tie
being broken then.
"""
probs = logits.softmax(dim=-1)
probs_sort, _ = probs.sort(dim=-1, descending=False)
if k is not None:
logits = apply_top_k_only(logits, k)
top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, )
top_k_count = top_k_count.unsqueeze(dim=1)
top_k_cutoff = probs_sort.gather(-1, top_k_count)
# Make sure the no top-k rows are no-op.
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
elements_to_discard = probs < top_k_cutoff
logits.masked_fill_(elements_to_discard, -float("inf"))
if p is not None:
probs = logits.softmax(dim=-1)
probs_sort, _ = probs.sort(dim=-1, descending=False)
cumprob = torch.cumsum(probs_sort, dim=-1)
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
top_p_mask[:, -1] = False # at least one
@ -224,7 +227,7 @@ def apply_top_k_only(
max_top_k = k.max()
# topk.values tensor has shape [batch_size, max_top_k].
# Convert top k to 0-based index in range [0, max_top_k).
k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1)
k_index = k.sub_(1).unsqueeze(1)
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
# Handle non-topk rows.
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))

View File

@ -10,7 +10,7 @@ DEFAULT_SAMPLING_PARAMS = dict(
temperature=-1.0,
min_p=0.0,
# strictly disabled for now
# top_k=-1,
top_k=0,
# top_p=0.0,
# frequency_penalties=0.0,
# presence_penalties=0.0,
@ -99,11 +99,13 @@ class TPUSupportedSamplingMetadata:
fill_slice(input_batch.temperature_cpu_tensor,
DEFAULT_SAMPLING_PARAMS["temperature"])
# TODO Temporarily disabled until sampling options are enabled
# fill_slice(input_batch.top_p_cpu_tensor)
# fill_slice(input_batch.top_k_cpu_tensor)
fill_slice(input_batch.min_p_cpu_tensor,
DEFAULT_SAMPLING_PARAMS["min_p"])
fill_slice(input_batch.top_k_cpu_tensor,
DEFAULT_SAMPLING_PARAMS["top_k"])
# TODO Temporarily disabled until sampling options are enabled
# fill_slice(input_batch.top_p_cpu_tensor,
# DEFAULT_SAMPLING_PARAMS["top_p"])
# Slice persistent device tensors to a fixed pre-compiled padded shape.
return cls(
@ -112,6 +114,7 @@ class TPUSupportedSamplingMetadata:
all_greedy=input_batch.all_greedy,
# TODO enable more and avoid returning None values
top_p=None, # input_batch.top_p[:padded_num_reqs],
top_k=None, # input_batch.top_k[:padded_num_reqs],
top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
xla_device),
min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
xla_device))

View File

@ -920,14 +920,19 @@ class TPUModelRunner:
device=self.device)
torch._dynamo.mark_dynamic(indices, 0)
self.select_hidden_states(dummy_hidden, indices)
logger.info(" -- num_tokens: %d", num_tokens)
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
num_reqs)
# Requests can't be more than tokens. But do compile for the
# next bigger value in case num_tokens uses bucketed padding.
if num_reqs >= min(num_tokens, self.max_num_reqs):
break
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
self._update_num_xla_graphs("select_hidden_states")
def _precompile_sample_from_hidden(self) -> None:
logger.info("Compiling sampling with different input shapes.")
logger.info("Compiling sampling with different num_reqs.")
start = time.perf_counter()
hsize = self.model_config.get_hidden_size()
for num_reqs in self.num_reqs_paddings: