[V1][TPU] Enable Top K (#15489)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
Co-authored-by: Hyesoo Yang <hyeygit@gmail.com>
Author: Nicolò Lucchesi (committed by GitHub), 2025-04-17 20:18:11 +02:00
parent 5989f4684d
commit eb5819b2d9
6 changed files with 69 additions and 26 deletions

View File

@@ -1,4 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
+import random
+
 import pytest

 from vllm import LLM, envs
@@ -39,3 +41,19 @@ def test_sampler_different(model_name: str):
     # Unsupported `seed` param.
     sampling_params = SamplingParams(temperature=0.3, seed=42)
     output2 = llm.generate(prompts, sampling_params)
+
+    # Batch-case with TopK
+    for B in [4, 16]:
+        p = prompts * B
+        sampling_params = [
+            SamplingParams(
+                temperature=0.1,
+                min_p=0.8,
+                max_tokens=64,
+                # Vary number of ks
+                top_k=random.randint(4, 12)) for _ in range(B)
+        ]
+        # Make sure first two reqs have the same K
+        sampling_params[0] = sampling_params[1]
+        output = llm.generate(p, sampling_params)
+        assert output[0].outputs[0].text == output[1].outputs[0].text
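
For context, per-request top-k rides on the same SamplingParams-list API the test above uses; a minimal standalone usage sketch, with a placeholder model name and prompts:

    from vllm import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # placeholder model

    prompts = ["The capital of France is", "The capital of Italy is"]
    # One SamplingParams per prompt, each carrying its own top_k.
    params = [
        SamplingParams(temperature=0.7, top_k=4, max_tokens=16),
        SamplingParams(temperature=0.7, top_k=12, max_tokens=16),
    ]
    for out in llm.generate(prompts, params):
        print(out.outputs[0].text)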

View File

@@ -5,7 +5,8 @@ import pytest
 import torch

 from vllm.platforms import current_platform
-from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu
+from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
+                                                  apply_top_k_top_p_tpu)

 if not current_platform.is_tpu():
     pytest.skip("This test needs a TPU.", allow_module_level=True)
@@ -16,6 +17,25 @@ VOCAB_SIZE = 128 * 1024
 TOLERANCE = 1e-6


+def test_topk_equivalence_to_native_impl():
+    with torch.device(xm.xla_device()):
+        xm.set_rng_state(seed=33)
+
+        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
+
+        # Random top-k values between 1 and 10.
+        k = torch.randint(1, 10, (BATCH_SIZE, ))
+
+        # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
+        k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool),
+                       VOCAB_SIZE)
+
+        result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)
+        result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
+
+        assert torch.allclose(result_native, result_tpu)
+
+
 def test_topp_result_sums_past_p():
     with torch.device(xm.xla_device()):
         xm.set_rng_state(seed=33)

View File

@@ -103,7 +103,6 @@ if TYPE_CHECKING:
     VLLM_DP_MASTER_PORT: int = 0
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
-    VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
     VLLM_USE_DEEP_GEMM: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
@@ -685,11 +684,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_V0_USE_OUTLINES_CACHE":
     lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",

-    # If set, disables TPU-specific optimization for top-k & top-p sampling
-    "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
-    lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
-    if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
-
     # Gap between padding buckets for the forward pass. So we have
     # 8, we will run forward pass with [16, 24, 32, ...].
     "VLLM_TPU_BUCKET_PADDING_GAP":

View File

@@ -72,13 +72,6 @@ class TopKTopPSampler(nn.Module):
                         "best performance, please install FlashInfer.")
                     self.forward = self.forward_native
         elif current_platform.is_tpu():
-            if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION:
-                logger.warning(
-                    "TPU-specific optimization for top-k & top-p sampling are "
-                    "disabled, falling back to PyTorch-native implementation "
-                    "which could be very slow.")
-                self.forward = self.forward_native
-            else:
-                self.forward = self.forward_tpu
+            self.forward = self.forward_tpu
         else:
             self.forward = self.forward_native
@@ -146,12 +139,22 @@ def apply_top_k_top_p_tpu(
     chance of being chosen during final sampling, so we can consider the tie
     being broken then.
     """
-    if k is not None:
-        logits = apply_top_k_only(logits, k)
-
-    if p is not None:
-        probs = logits.softmax(dim=-1)
-        probs_sort, _ = probs.sort(dim=-1, descending=False)
+    probs = logits.softmax(dim=-1)
+    probs_sort, _ = probs.sort(dim=-1, descending=False)
+
+    if k is not None:
+        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
+        top_k_count = top_k_count.unsqueeze(dim=1)
+        top_k_cutoff = probs_sort.gather(-1, top_k_count)
+
+        # Make sure the no top-k rows are no-op.
+        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
+        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
+
+        elements_to_discard = probs < top_k_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
+
+    if p is not None:
         cumprob = torch.cumsum(probs_sort, dim=-1)
         top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
         top_p_mask[:, -1] = False  # at least one
@@ -224,7 +227,7 @@ def apply_top_k_only(
     max_top_k = k.max()
     # topk.values tensor has shape [batch_size, max_top_k].
     # Convert top k to 0-based index in range [0, max_top_k).
-    k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1)
+    k_index = k.sub_(1).unsqueeze(1)
     top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
     # Handle non-topk rows.
     top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
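
For reference, a self-contained CPU sketch of the cutoff trick the new top-k branch relies on: sort the probabilities in ascending order, gather the (V - k)-th entry of each row as that row's threshold, and drop everything strictly below it, so no scatter is needed and rows with k equal to the vocab size pass through untouched. The helper name and the toy tensors below are invented for illustration; this is not code from the commit.

    import torch

    def topk_mask_via_cutoff(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        # Sort probabilities ascending; the (V - k)-th entry of each row is the
        # smallest probability still inside that row's top-k set, i.e. the cutoff.
        vocab_size = logits.shape[1]
        probs = logits.softmax(dim=-1)
        probs_sort, _ = probs.sort(dim=-1, descending=False)
        top_k_count = (vocab_size - k.to(torch.long)).unsqueeze(1)
        top_k_cutoff = probs_sort.gather(-1, top_k_count)
        # Rows with k == vocab_size mean "top-k disabled": push their cutoff to
        # -inf so nothing gets discarded for them.
        no_top_k_mask = (k == vocab_size).unsqueeze(1)
        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
        return logits.masked_fill(probs < top_k_cutoff, -float("inf"))

    logits = torch.tensor([[1.0, 3.0, 2.0, 0.5],
                           [0.1, 0.2, 0.3, 0.4]])
    k = torch.tensor([2, 4])  # top-2 for row 0; row 1 has top-k disabled
    print(topk_mask_via_cutoff(logits, k))
    # Row 0 keeps the logits 3.0 and 2.0 and masks the rest to -inf;
    # row 1 comes back unchanged because k equals the vocab size.

For the same k (and p=None) this masking matches the native apply_top_k_top_p result, which is exactly what the new equivalence test asserts.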

View File

@@ -10,7 +10,7 @@ DEFAULT_SAMPLING_PARAMS = dict(
     temperature=-1.0,
     min_p=0.0,
     # strictly disabled for now
-    # top_k=-1,
+    top_k=0,
     # top_p=0.0,
     # frequency_penalties=0.0,
     # presence_penalties=0.0,
@@ -99,11 +99,13 @@ class TPUSupportedSamplingMetadata:
         fill_slice(input_batch.temperature_cpu_tensor,
                    DEFAULT_SAMPLING_PARAMS["temperature"])
-        # TODO Temporarily disabled until sampling options are enabled
-        # fill_slice(input_batch.top_p_cpu_tensor)
-        # fill_slice(input_batch.top_k_cpu_tensor)
         fill_slice(input_batch.min_p_cpu_tensor,
                    DEFAULT_SAMPLING_PARAMS["min_p"])
+        fill_slice(input_batch.top_k_cpu_tensor,
+                   DEFAULT_SAMPLING_PARAMS["top_k"])
+        # TODO Temporarily disabled until sampling options are enabled
+        # fill_slice(input_batch.top_p_cpu_tensor,
+        #            DEFAULT_SAMPLING_PARAMS["top_p"])

         # Slice persistent device tensors to a fixed pre-compiled padded shape.
         return cls(
@@ -112,6 +114,7 @@ class TPUSupportedSamplingMetadata:
             all_greedy=input_batch.all_greedy,
             # TODO enable more and avoid returning None values
             top_p=None,  # input_batch.top_p[:padded_num_reqs],
-            top_k=None,  # input_batch.top_k[:padded_num_reqs],
+            top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
+                xla_device),
             min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
                 xla_device))
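
For intuition (not code from the commit): the fill_slice calls above are assumed to overwrite the padded tail of each persistent CPU tensor with its default, so the slots beyond the real request count carry the neutral value, here top_k=0, instead of stale data. The sizes and values below are invented.

    import torch

    # Invented batch: 3 live requests padded up to 8 pre-compiled slots.
    num_reqs, padded_num_reqs = 3, 8
    top_k_cpu = torch.zeros(padded_num_reqs, dtype=torch.int32)
    top_k_cpu[:num_reqs] = torch.tensor([4, 12, 7], dtype=torch.int32)

    def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> None:
        # Reset only the padded tail; the live prefix keeps its per-request values.
        cpu_tensor[num_reqs:padded_num_reqs] = fill_val

    fill_slice(top_k_cpu, 0)  # DEFAULT_SAMPLING_PARAMS["top_k"]
    print(top_k_cpu)  # tensor([ 4, 12,  7,  0,  0,  0,  0,  0], dtype=torch.int32)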

View File

@@ -920,14 +920,19 @@ class TPUModelRunner:
                                       device=self.device)
                 torch._dynamo.mark_dynamic(indices, 0)
                 self.select_hidden_states(dummy_hidden, indices)
-                logger.info(" -- num_tokens: %d", num_tokens)
+                logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
+                            num_reqs)
+                # Requests can't be more than tokens. But do compile for the
+                # next bigger value in case num_tokens uses bucketed padding.
+                if num_reqs >= min(num_tokens, self.max_num_reqs):
+                    break
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in in %.2f [secs].", end - start)
         self._update_num_xla_graphs("select_hidden_states")

     def _precompile_sample_from_hidden(self) -> None:
-        logger.info("Compiling sampling with different input shapes.")
+        logger.info("Compiling sampling with different num_reqs.")
         start = time.perf_counter()
         hsize = self.model_config.get_hidden_size()
         for num_reqs in self.num_reqs_paddings:
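
The early exit added above prunes (num_tokens, num_reqs) shape combinations during precompilation. A toy sketch of the rule, assuming the surrounding loop iterates num_tokens on the outside and num_reqs on the inside, as the new log line suggests; the padding lists and the limit below are invented.

    # Invented paddings, only to visualize which shapes survive the break rule.
    num_tokens_paddings = [16, 32, 64]
    num_reqs_paddings = [8, 16, 32]
    max_num_reqs = 32

    compiled = []
    for num_tokens in num_tokens_paddings:
        for num_reqs in num_reqs_paddings:
            compiled.append((num_tokens, num_reqs))
            # Requests can't exceed tokens, but keep the first bucket at or
            # above the limit in case num_tokens itself was padded up.
            if num_reqs >= min(num_tokens, max_num_reqs):
                break

    print(compiled)
    # [(16, 8), (16, 16), (32, 8), (32, 16), (32, 32), (64, 8), (64, 16), (64, 32)]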