[V1][TPU] Enable Top K (#15489)
Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Hyesoo Yang <hyeygit@gmail.com> Co-authored-by: Hyesoo Yang <hyeygit@gmail.com>
This commit is contained in:
parent
5989f4684d
commit
eb5819b2d9
@ -1,4 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, envs
|
||||
@ -39,3 +41,19 @@ def test_sampler_different(model_name: str):
|
||||
# Unsupported `seed` param.
|
||||
sampling_params = SamplingParams(temperature=0.3, seed=42)
|
||||
output2 = llm.generate(prompts, sampling_params)
|
||||
|
||||
# Batch-case with TopK
|
||||
for B in [4, 16]:
|
||||
p = prompts * B
|
||||
sampling_params = [
|
||||
SamplingParams(
|
||||
temperature=0.1,
|
||||
min_p=0.8,
|
||||
max_tokens=64,
|
||||
# Vary number of ks
|
||||
top_k=random.randint(4, 12)) for _ in range(B)
|
||||
]
|
||||
# Make sure first two reqs have the same K
|
||||
sampling_params[0] = sampling_params[1]
|
||||
output = llm.generate(p, sampling_params)
|
||||
assert output[0].outputs[0].text == output[1].outputs[0].text
|
||||
|
@ -5,7 +5,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
|
||||
apply_top_k_top_p_tpu)
|
||||
|
||||
if not current_platform.is_tpu():
|
||||
pytest.skip("This test needs a TPU.", allow_module_level=True)
|
||||
@ -16,6 +17,25 @@ VOCAB_SIZE = 128 * 1024
|
||||
TOLERANCE = 1e-6
|
||||
|
||||
|
||||
def test_topk_equivalence_to_native_impl():
|
||||
with torch.device(xm.xla_device()):
|
||||
xm.set_rng_state(seed=33)
|
||||
|
||||
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
|
||||
|
||||
# Random top-k values between 1 and 10.
|
||||
k = torch.randint(1, 10, (BATCH_SIZE, ))
|
||||
|
||||
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
|
||||
k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool),
|
||||
VOCAB_SIZE)
|
||||
|
||||
result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)
|
||||
|
||||
result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
|
||||
assert torch.allclose(result_native, result_tpu)
|
||||
|
||||
|
||||
def test_topp_result_sums_past_p():
|
||||
with torch.device(xm.xla_device()):
|
||||
xm.set_rng_state(seed=33)
|
||||
|
@ -103,7 +103,6 @@ if TYPE_CHECKING:
|
||||
VLLM_DP_MASTER_PORT: int = 0
|
||||
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
|
||||
VLLM_V0_USE_OUTLINES_CACHE: bool = False
|
||||
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
|
||||
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
|
||||
VLLM_USE_DEEP_GEMM: bool = False
|
||||
VLLM_XGRAMMAR_CACHE_MB: int = 0
|
||||
@ -685,11 +684,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_V0_USE_OUTLINES_CACHE":
|
||||
lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
|
||||
|
||||
# If set, disables TPU-specific optimization for top-k & top-p sampling
|
||||
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
|
||||
lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
|
||||
if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
|
||||
|
||||
# Gap between padding buckets for the forward pass. So we have
|
||||
# 8, we will run forward pass with [16, 24, 32, ...].
|
||||
"VLLM_TPU_BUCKET_PADDING_GAP":
|
||||
|
@ -72,14 +72,7 @@ class TopKTopPSampler(nn.Module):
|
||||
"best performance, please install FlashInfer.")
|
||||
self.forward = self.forward_native
|
||||
elif current_platform.is_tpu():
|
||||
if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION:
|
||||
logger.warning(
|
||||
"TPU-specific optimization for top-k & top-p sampling are "
|
||||
"disabled, falling back to PyTorch-native implementation "
|
||||
"which could be very slow.")
|
||||
self.forward = self.forward_native
|
||||
else:
|
||||
self.forward = self.forward_tpu
|
||||
self.forward = self.forward_tpu
|
||||
else:
|
||||
self.forward = self.forward_native
|
||||
|
||||
@ -146,12 +139,22 @@ def apply_top_k_top_p_tpu(
|
||||
chance of being chosen during final sampling, so we can consider the tie
|
||||
being broken then.
|
||||
"""
|
||||
probs = logits.softmax(dim=-1)
|
||||
probs_sort, _ = probs.sort(dim=-1, descending=False)
|
||||
|
||||
if k is not None:
|
||||
logits = apply_top_k_only(logits, k)
|
||||
top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, )
|
||||
top_k_count = top_k_count.unsqueeze(dim=1)
|
||||
top_k_cutoff = probs_sort.gather(-1, top_k_count)
|
||||
|
||||
# Make sure the no top-k rows are no-op.
|
||||
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
|
||||
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
|
||||
|
||||
elements_to_discard = probs < top_k_cutoff
|
||||
logits.masked_fill_(elements_to_discard, -float("inf"))
|
||||
|
||||
if p is not None:
|
||||
probs = logits.softmax(dim=-1)
|
||||
probs_sort, _ = probs.sort(dim=-1, descending=False)
|
||||
cumprob = torch.cumsum(probs_sort, dim=-1)
|
||||
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
|
||||
top_p_mask[:, -1] = False # at least one
|
||||
@ -224,7 +227,7 @@ def apply_top_k_only(
|
||||
max_top_k = k.max()
|
||||
# topk.values tensor has shape [batch_size, max_top_k].
|
||||
# Convert top k to 0-based index in range [0, max_top_k).
|
||||
k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1)
|
||||
k_index = k.sub_(1).unsqueeze(1)
|
||||
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
|
||||
# Handle non-topk rows.
|
||||
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
|
||||
|
@ -10,7 +10,7 @@ DEFAULT_SAMPLING_PARAMS = dict(
|
||||
temperature=-1.0,
|
||||
min_p=0.0,
|
||||
# strictly disabled for now
|
||||
# top_k=-1,
|
||||
top_k=0,
|
||||
# top_p=0.0,
|
||||
# frequency_penalties=0.0,
|
||||
# presence_penalties=0.0,
|
||||
@ -99,11 +99,13 @@ class TPUSupportedSamplingMetadata:
|
||||
|
||||
fill_slice(input_batch.temperature_cpu_tensor,
|
||||
DEFAULT_SAMPLING_PARAMS["temperature"])
|
||||
# TODO Temporarily disabled until sampling options are enabled
|
||||
# fill_slice(input_batch.top_p_cpu_tensor)
|
||||
# fill_slice(input_batch.top_k_cpu_tensor)
|
||||
fill_slice(input_batch.min_p_cpu_tensor,
|
||||
DEFAULT_SAMPLING_PARAMS["min_p"])
|
||||
fill_slice(input_batch.top_k_cpu_tensor,
|
||||
DEFAULT_SAMPLING_PARAMS["top_k"])
|
||||
# TODO Temporarily disabled until sampling options are enabled
|
||||
# fill_slice(input_batch.top_p_cpu_tensor,
|
||||
# DEFAULT_SAMPLING_PARAMS["top_p"])
|
||||
|
||||
# Slice persistent device tensors to a fixed pre-compiled padded shape.
|
||||
return cls(
|
||||
@ -112,6 +114,7 @@ class TPUSupportedSamplingMetadata:
|
||||
all_greedy=input_batch.all_greedy,
|
||||
# TODO enable more and avoid returning None values
|
||||
top_p=None, # input_batch.top_p[:padded_num_reqs],
|
||||
top_k=None, # input_batch.top_k[:padded_num_reqs],
|
||||
top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
|
||||
xla_device),
|
||||
min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
|
||||
xla_device))
|
||||
|
@ -920,14 +920,19 @@ class TPUModelRunner:
|
||||
device=self.device)
|
||||
torch._dynamo.mark_dynamic(indices, 0)
|
||||
self.select_hidden_states(dummy_hidden, indices)
|
||||
logger.info(" -- num_tokens: %d", num_tokens)
|
||||
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
|
||||
num_reqs)
|
||||
# Requests can't be more than tokens. But do compile for the
|
||||
# next bigger value in case num_tokens uses bucketed padding.
|
||||
if num_reqs >= min(num_tokens, self.max_num_reqs):
|
||||
break
|
||||
xm.wait_device_ops()
|
||||
end = time.perf_counter()
|
||||
logger.info("Compilation finished in in %.2f [secs].", end - start)
|
||||
self._update_num_xla_graphs("select_hidden_states")
|
||||
|
||||
def _precompile_sample_from_hidden(self) -> None:
|
||||
logger.info("Compiling sampling with different input shapes.")
|
||||
logger.info("Compiling sampling with different num_reqs.")
|
||||
start = time.perf_counter()
|
||||
hsize = self.model_config.get_hidden_size()
|
||||
for num_reqs in self.num_reqs_paddings:
|
||||
|
Loading…
x
Reference in New Issue
Block a user