From eb5819b2d9ff4e5a019de97c333bbedf2a2def1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Thu, 17 Apr 2025 20:18:11 +0200 Subject: [PATCH] [V1][TPU] Enable Top K (#15489) Signed-off-by: NickLucche Signed-off-by: Hyesoo Yang Co-authored-by: Hyesoo Yang --- tests/v1/tpu/test_sampler.py | 18 +++++++++++++++++ tests/v1/tpu/test_topk_topp_sampler.py | 22 +++++++++++++++++++- vllm/envs.py | 6 ------ vllm/v1/sample/ops/topk_topp_sampler.py | 27 ++++++++++++++----------- vllm/v1/sample/tpu/metadata.py | 13 +++++++----- vllm/v1/worker/tpu_model_runner.py | 9 +++++++-- 6 files changed, 69 insertions(+), 26 deletions(-) diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py index 0147da53..74ad8140 100644 --- a/tests/v1/tpu/test_sampler.py +++ b/tests/v1/tpu/test_sampler.py @@ -1,4 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import random + import pytest from vllm import LLM, envs @@ -39,3 +41,19 @@ def test_sampler_different(model_name: str): # Unsupported `seed` param. sampling_params = SamplingParams(temperature=0.3, seed=42) output2 = llm.generate(prompts, sampling_params) + + # Batch-case with TopK + for B in [4, 16]: + p = prompts * B + sampling_params = [ + SamplingParams( + temperature=0.1, + min_p=0.8, + max_tokens=64, + # Vary number of ks + top_k=random.randint(4, 12)) for _ in range(B) + ] + # Make sure first two reqs have the same K + sampling_params[0] = sampling_params[1] + output = llm.generate(p, sampling_params) + assert output[0].outputs[0].text == output[1].outputs[0].text diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py index dce0303e..ff9217f8 100644 --- a/tests/v1/tpu/test_topk_topp_sampler.py +++ b/tests/v1/tpu/test_topk_topp_sampler.py @@ -5,7 +5,8 @@ import pytest import torch from vllm.platforms import current_platform -from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu +from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p, + apply_top_k_top_p_tpu) if not current_platform.is_tpu(): pytest.skip("This test needs a TPU.", allow_module_level=True) @@ -16,6 +17,25 @@ VOCAB_SIZE = 128 * 1024 TOLERANCE = 1e-6 +def test_topk_equivalence_to_native_impl(): + with torch.device(xm.xla_device()): + xm.set_rng_state(seed=33) + + logits = torch.rand((BATCH_SIZE, VOCAB_SIZE)) + + # Random top-k values between 1 and 10. + k = torch.randint(1, 10, (BATCH_SIZE, )) + + # Set k=vocab_size for ~50% of requests in the batch (top-k disabled). 
+ k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), + VOCAB_SIZE) + + result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None) + + result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None) + assert torch.allclose(result_native, result_tpu) + + def test_topp_result_sums_past_p(): with torch.device(xm.xla_device()): xm.set_rng_state(seed=33) diff --git a/vllm/envs.py b/vllm/envs.py index d32968c3..76b5a4d8 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -103,7 +103,6 @@ if TYPE_CHECKING: VLLM_DP_MASTER_PORT: int = 0 VLLM_MARLIN_USE_ATOMIC_ADD: bool = False VLLM_V0_USE_OUTLINES_CACHE: bool = False - VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_USE_DEEP_GEMM: bool = False VLLM_XGRAMMAR_CACHE_MB: int = 0 @@ -685,11 +684,6 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_V0_USE_OUTLINES_CACHE": lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1", - # If set, disables TPU-specific optimization for top-k & top-p sampling - "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION": - lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"])) - if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None, - # Gap between padding buckets for the forward pass. So we have # 8, we will run forward pass with [16, 24, 32, ...]. "VLLM_TPU_BUCKET_PADDING_GAP": diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index f69623ed..745b81de 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -72,14 +72,7 @@ class TopKTopPSampler(nn.Module): "best performance, please install FlashInfer.") self.forward = self.forward_native elif current_platform.is_tpu(): - if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: - logger.warning( - "TPU-specific optimization for top-k & top-p sampling are " - "disabled, falling back to PyTorch-native implementation " - "which could be very slow.") - self.forward = self.forward_native - else: - self.forward = self.forward_tpu + self.forward = self.forward_tpu else: self.forward = self.forward_native @@ -146,12 +139,22 @@ def apply_top_k_top_p_tpu( chance of being chosen during final sampling, so we can consider the tie being broken then. """ + probs = logits.softmax(dim=-1) + probs_sort, _ = probs.sort(dim=-1, descending=False) + if k is not None: - logits = apply_top_k_only(logits, k) + top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, ) + top_k_count = top_k_count.unsqueeze(dim=1) + top_k_cutoff = probs_sort.gather(-1, top_k_count) + + # Make sure the no top-k rows are no-op. + no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1) + top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf")) + + elements_to_discard = probs < top_k_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) if p is not None: - probs = logits.softmax(dim=-1) - probs_sort, _ = probs.sort(dim=-1, descending=False) cumprob = torch.cumsum(probs_sort, dim=-1) top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) top_p_mask[:, -1] = False # at least one @@ -224,7 +227,7 @@ def apply_top_k_only( max_top_k = k.max() # topk.values tensor has shape [batch_size, max_top_k]. # Convert top k to 0-based index in range [0, max_top_k). - k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1) + k_index = k.sub_(1).unsqueeze(1) top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long()) # Handle non-topk rows. 
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf")) diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py index 3950fda3..917d8baf 100644 --- a/vllm/v1/sample/tpu/metadata.py +++ b/vllm/v1/sample/tpu/metadata.py @@ -10,7 +10,7 @@ DEFAULT_SAMPLING_PARAMS = dict( temperature=-1.0, min_p=0.0, # strictly disabled for now - # top_k=-1, + top_k=0, # top_p=0.0, # frequency_penalties=0.0, # presence_penalties=0.0, @@ -99,11 +99,13 @@ class TPUSupportedSamplingMetadata: fill_slice(input_batch.temperature_cpu_tensor, DEFAULT_SAMPLING_PARAMS["temperature"]) - # TODO Temporarily disabled until sampling options are enabled - # fill_slice(input_batch.top_p_cpu_tensor) - # fill_slice(input_batch.top_k_cpu_tensor) fill_slice(input_batch.min_p_cpu_tensor, DEFAULT_SAMPLING_PARAMS["min_p"]) + fill_slice(input_batch.top_k_cpu_tensor, + DEFAULT_SAMPLING_PARAMS["top_k"]) + # TODO Temporarily disabled until sampling options are enabled + # fill_slice(input_batch.top_p_cpu_tensor, + # DEFAULT_SAMPLING_PARAMS["top_p"]) # Slice persistent device tensors to a fixed pre-compiled padded shape. return cls( @@ -112,6 +114,7 @@ class TPUSupportedSamplingMetadata: all_greedy=input_batch.all_greedy, # TODO enable more and avoid returning None values top_p=None, # input_batch.top_p[:padded_num_reqs], - top_k=None, # input_batch.top_k[:padded_num_reqs], + top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to( + xla_device), min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to( xla_device)) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index b66cd8d2..f31454ab 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -920,14 +920,19 @@ class TPUModelRunner: device=self.device) torch._dynamo.mark_dynamic(indices, 0) self.select_hidden_states(dummy_hidden, indices) - logger.info(" -- num_tokens: %d", num_tokens) + logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, + num_reqs) + # Requests can't be more than tokens. But do compile for the + # next bigger value in case num_tokens uses bucketed padding. + if num_reqs >= min(num_tokens, self.max_num_reqs): + break xm.wait_device_ops() end = time.perf_counter() logger.info("Compilation finished in in %.2f [secs].", end - start) self._update_num_xla_graphs("select_hidden_states") def _precompile_sample_from_hidden(self) -> None: - logger.info("Compiling sampling with different input shapes.") + logger.info("Compiling sampling with different num_reqs.") start = time.perf_counter() hsize = self.model_config.get_hidden_size() for num_reqs in self.num_reqs_paddings:
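
Note (not part of the patch): the snippet below is a minimal CPU-only sketch of the sort-and-cutoff trick that apply_top_k_top_p_tpu relies on, checked against a plain per-row torch.topk reference. The helper names (top_k_mask_via_cutoff, reference_top_k) and the toy batch/vocab sizes are made up for the illustration; the real implementation additionally reuses the sorted probabilities for the top-p path and runs under torch_xla.

    import torch

    BATCH, VOCAB = 4, 32


    def top_k_mask_via_cutoff(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        # Mask out everything below each row's k-th largest probability.
        # Every op keeps the full (batch, vocab) shape, which is what makes
        # this formulation XLA/TPU friendly.
        probs = logits.softmax(dim=-1)
        probs_sort, _ = probs.sort(dim=-1, descending=False)
        # In the ascending sort, index (vocab - k) holds the k-th largest prob.
        cutoff_idx = (probs_sort.size(1) - k.to(torch.long)).unsqueeze(1)
        cutoff = probs_sort.gather(-1, cutoff_idx)
        # Rows with k == vocab_size have top-k disabled: never discard anything.
        cutoff = cutoff.masked_fill((k == logits.shape[1]).unsqueeze(1), -float("inf"))
        return logits.masked_fill(probs < cutoff, -float("inf"))


    def reference_top_k(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        # Straightforward per-row torch.topk reference (data-dependent shapes).
        out = torch.full_like(logits, -float("inf"))
        for row in range(logits.shape[0]):
            vals, idx = logits[row].topk(int(k[row]))
            out[row, idx] = vals
        return out


    torch.manual_seed(0)
    logits = torch.randn(BATCH, VOCAB)
    k = torch.tensor([1, 4, 8, VOCAB])  # last row: top-k disabled
    # Ties at the cutoff would keep extra tokens (see the docstring in the
    # patch); random float logits make that a non-issue here.
    assert torch.equal(top_k_mask_via_cutoff(logits, k), reference_top_k(logits, k))

The design point is that every op in the cutoff formulation keeps the static (batch, vocab) shape, so XLA can compile a single graph regardless of the per-request k values, whereas the torch.topk-based path sizes its output by k.max() and would typically trigger recompilation on TPU.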