diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py
index f535abed..0147da53 100644
--- a/tests/v1/tpu/test_sampler.py
+++ b/tests/v1/tpu/test_sampler.py
@@ -34,3 +34,8 @@ def test_sampler_different(model_name: str):
     sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
     output2 = llm.generate(prompts, sampling_params)
     assert output[0].outputs[0].text != output2[0].outputs[0].text
+
+    with pytest.raises(ValueError):
+        # Unsupported `seed` param.
+        sampling_params = SamplingParams(temperature=0.3, seed=42)
+        output2 = llm.generate(prompts, sampling_params)
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 61e84a6d..ada599c2 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -7,7 +7,7 @@ import torch
 import vllm.envs as envs
 from vllm.inputs import PromptType
 from vllm.logger import init_logger
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, SamplingType
 
 from .interface import Platform, PlatformEnum, _Backend
 
@@ -145,7 +145,10 @@ class TpuPlatform(Platform):
         params: Union[SamplingParams, PoolingParams],
     ) -> None:
         """Raises if this request is unsupported on this platform"""
-        if isinstance(params,
-                      SamplingParams) and params.guided_decoding is not None:
-            raise ValueError("Structured output is not supported on "
-                             f"{cls.device_name}.")
+        if isinstance(params, SamplingParams):
+            if params.guided_decoding is not None:
+                raise ValueError("Structured output is not supported on "
+                                 f"{cls.device_name}.")
+            if params.sampling_type == SamplingType.RANDOM_SEED:
+                raise ValueError(
+                    "Torch XLA does not support per-request seed.")
diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py
index 10995d67..3950fda3 100644
--- a/vllm/v1/sample/tpu/metadata.py
+++ b/vllm/v1/sample/tpu/metadata.py
@@ -33,10 +33,6 @@ class TPUSupportedSamplingMetadata:
     # Greedy sampling flag for compiling single xla graph.
     all_greedy: bool = True
 
-    # Generator not supported by xla
-    generators: dict[int,
-                     torch.Generator] = field(default_factory=lambda: dict())
-
     # unsupported, you need to return an extra tensor of static size BxV
     max_num_logprobs = None
 
@@ -57,6 +53,15 @@ class TPUSupportedSamplingMetadata:
     allowed_token_ids_mask = None
     bad_words_token_ids = None
 
+    # Generator not supported by xla
+    _generators: dict[int,
+                      torch.Generator] = field(default_factory=lambda: dict())
+
+    @property
+    def generators(self) -> dict[int, torch.Generator]:
+        # Generator not supported by torch/xla. This field must be immutable.
+        return self._generators
+
     @classmethod
     def from_input_batch(
         cls,
@@ -109,5 +114,4 @@ class TPUSupportedSamplingMetadata:
             top_p=None,  # input_batch.top_p[:padded_num_reqs],
             top_k=None,  # input_batch.top_k[:padded_num_reqs],
             min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
-                xla_device),
-            generators=input_batch.generators)
+                xla_device))
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index e6c5a899..69251d8b 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -23,7 +23,6 @@ from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
-from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
@@ -267,11 +266,6 @@ class TPUModelRunner:
         for new_req_data in scheduler_output.scheduled_new_reqs:
             req_id = new_req_data.req_id
             sampling_params = new_req_data.sampling_params
-            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-                generator = torch.Generator(device=self.device)
-                generator.manual_seed(sampling_params.seed)
-            else:
-                generator = None
 
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
@@ -280,7 +274,7 @@
                 mm_inputs=new_req_data.mm_inputs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
-                generator=generator,
+                generator=None,
                 block_ids=new_req_data.block_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 output_token_ids=[],