[TPU][V1] Disable per-request seed/Generator (#16172)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi 2025-04-10 23:05:44 +02:00 committed by GitHub
parent 7cd0bd7212
commit 3cc9af88ff
4 changed files with 24 additions and 18 deletions

tests/v1/tpu/test_sampler.py

@@ -34,3 +34,8 @@ def test_sampler_different(model_name: str):
     sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
     output2 = llm.generate(prompts, sampling_params)
     assert output[0].outputs[0].text != output2[0].outputs[0].text
+
+    with pytest.raises(ValueError):
+        # Unsupported `seed` param.
+        sampling_params = SamplingParams(temperature=0.3, seed=42)
+        output2 = llm.generate(prompts, sampling_params)
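
The user-visible behavior this test pins down, as a minimal sketch (assuming vLLM's public `LLM`/`SamplingParams` API on a TPU backend; the model name and prompt are illustrative):

    from vllm import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # illustrative model
    prompts = ["The capital of France is"]

    # Seedless random sampling still works on TPU.
    ok = llm.generate(prompts, SamplingParams(temperature=0.3, max_tokens=16))
    print(ok[0].outputs[0].text)

    # A per-request seed is now rejected eagerly with a ValueError.
    try:
        llm.generate(prompts, SamplingParams(temperature=0.3, seed=42))
    except ValueError as err:
        print(err)  # Torch XLA does not support per-request seed.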

vllm/platforms/tpu.py

@@ -7,7 +7,7 @@ import torch
 import vllm.envs as envs
 from vllm.inputs import PromptType
 from vllm.logger import init_logger
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, SamplingType

 from .interface import Platform, PlatformEnum, _Backend
@@ -145,7 +145,10 @@ class TpuPlatform(Platform):
         params: Union[SamplingParams, PoolingParams],
     ) -> None:
         """Raises if this request is unsupported on this platform"""
-        if isinstance(params,
-                      SamplingParams) and params.guided_decoding is not None:
-            raise ValueError("Structured output is not supported on "
-                             f"{cls.device_name}.")
+        if isinstance(params, SamplingParams):
+            if params.guided_decoding is not None:
+                raise ValueError("Structured output is not supported on "
+                                 f"{cls.device_name}.")
+            if params.sampling_type == SamplingType.RANDOM_SEED:
+                raise ValueError(
+                    "Torch XLA does not support per-request seed.")

vllm/v1/sample/tpu/metadata.py

@@ -33,10 +33,6 @@ class TPUSupportedSamplingMetadata:
     # Greedy sampling flag for compiling single xla graph.
     all_greedy: bool = True

-    # Generator not supported by xla
-    generators: dict[int,
-                     torch.Generator] = field(default_factory=lambda: dict())
-
     # unsupported, you need to return an extra tensor of static size BxV
     max_num_logprobs = None
@@ -57,6 +53,15 @@ class TPUSupportedSamplingMetadata:
     allowed_token_ids_mask = None
     bad_words_token_ids = None

+    # Generator not supported by xla
+    _generators: dict[int,
+                      torch.Generator] = field(default_factory=lambda: dict())
+
+    @property
+    def generators(self) -> dict[int, torch.Generator]:
+        # Generator not supported by torch/xla. This field must be immutable.
+        return self._generators
+
     @classmethod
     def from_input_batch(
         cls,
@@ -109,5 +114,4 @@ class TPUSupportedSamplingMetadata:
             top_p=None,  # input_batch.top_p[:padded_num_reqs],
             top_k=None,  # input_batch.top_k[:padded_num_reqs],
             min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
-                xla_device),
-            generators=input_batch.generators)
+                xla_device))
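
The `_generators` field plus a getter-only `@property` makes the public attribute read-only (reassignment raises `AttributeError`), which is the immutability the comment asks for. A standalone sketch of the idiom (note it prevents rebinding the attribute, not mutation of the dict itself):

    from dataclasses import dataclass, field

    @dataclass
    class Meta:
        # Private backing field; exposed read-only through the property below.
        _generators: dict[int, object] = field(default_factory=dict)

        @property
        def generators(self) -> dict[int, object]:
            return self._generators

    m = Meta()
    print(m.generators)               # {}
    try:
        m.generators = {0: object()}  # no setter defined
    except AttributeError:
        print("read-only")            # rebinding blocked; dict contents are not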

vllm/v1/worker/tpu_model_runner.py

@@ -23,7 +23,6 @@ from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
-from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
@@ -267,11 +266,6 @@ class TPUModelRunner:
         for new_req_data in scheduler_output.scheduled_new_reqs:
             req_id = new_req_data.req_id
             sampling_params = new_req_data.sampling_params
-            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-                generator = torch.Generator(device=self.device)
-                generator.manual_seed(sampling_params.seed)
-            else:
-                generator = None

             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
@@ -280,7 +274,7 @@ class TPUModelRunner:
                 mm_inputs=new_req_data.mm_inputs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
-                generator=generator,
+                generator=None,
                 block_ids=new_req_data.block_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 output_token_ids=[],
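
For contrast, the deleted branch mirrored the GPU runner's per-request generator setup, which XLA cannot honor; `torch.Generator(device=...)` with an XLA device is the unsupported piece. A sketch of the old logic in isolation:

    import torch

    def make_generator(seed: int | None, device: str) -> torch.Generator | None:
        """GPU-style path: build a per-request generator seeded from the request."""
        if seed is None:
            return None
        gen = torch.Generator(device=device)  # fails for XLA devices
        gen.manual_seed(seed)
        return gen

    # After this commit the TPU runner unconditionally stores generator=None,
    # and seeded requests never reach this point thanks to validate_request.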