[TPU][V1] Disable per-request seed/Generator (#16172)
Signed-off-by: NickLucche <nlucches@redhat.com>
commit 3cc9af88ff (parent 7cd0bd7212)
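Summary: per-request seeding is implemented with torch.Generator, which torch/XLA does not support. This change rejects seeded requests up front in TpuPlatform.validate_request and removes the now-dead Generator plumbing from the V1 TPU sampling metadata and model runner.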
@@ -34,3 +34,8 @@ def test_sampler_different(model_name: str):
     sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
     output2 = llm.generate(prompts, sampling_params)
     assert output[0].outputs[0].text != output2[0].outputs[0].text
+
+    with pytest.raises(ValueError):
+        # Unsupported `seed` param.
+        sampling_params = SamplingParams(temperature=0.3, seed=42)
+        output2 = llm.generate(prompts, sampling_params)
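A minimal sketch (not part of the diff) of how the new check surfaces to a user running on TPU; the model name and prompt are illustrative only:

    from vllm import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # any model served on TPU
    try:
        llm.generate(["Hello"], SamplingParams(temperature=0.3, seed=42))
    except ValueError as e:
        print(e)  # "Torch XLA does not support per-request seed."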
@@ -7,7 +7,7 @@ import torch

 import vllm.envs as envs
 from vllm.inputs import PromptType
 from vllm.logger import init_logger
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, SamplingType

 from .interface import Platform, PlatformEnum, _Backend
@@ -145,7 +145,10 @@ class TpuPlatform(Platform):
         params: Union[SamplingParams, PoolingParams],
     ) -> None:
         """Raises if this request is unsupported on this platform"""
-        if isinstance(params,
-                      SamplingParams) and params.guided_decoding is not None:
-            raise ValueError("Structured output is not supported on "
-                             f"{cls.device_name}.")
+        if isinstance(params, SamplingParams):
+            if params.guided_decoding is not None:
+                raise ValueError("Structured output is not supported on "
+                                 f"{cls.device_name}.")
+            if params.sampling_type == SamplingType.RANDOM_SEED:
+                raise ValueError(
+                    "Torch XLA does not support per-request seed.")
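For context, a hedged sketch of how `sampling_type` classifies requests (the exact rules live in vllm.sampling_params): any request with an explicit `seed` is RANDOM_SEED, which is what the new check rejects on TPU.

    from vllm.sampling_params import SamplingParams, SamplingType

    # seed=None -> RANDOM (or GREEDY when temperature is effectively zero)
    assert SamplingParams(temperature=0.3).sampling_type == SamplingType.RANDOM
    # any explicit seed -> RANDOM_SEED, now rejected by TpuPlatform.validate_request
    assert SamplingParams(temperature=0.3, seed=42).sampling_type == SamplingType.RANDOM_SEED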
@@ -33,10 +33,6 @@ class TPUSupportedSamplingMetadata:
     # Greedy sampling flag for compiling single xla graph.
     all_greedy: bool = True

-    # Generator not supported by xla
-    generators: dict[int,
-                     torch.Generator] = field(default_factory=lambda: dict())
-
     # unsupported, you need to return an extra tensor of static size BxV
     max_num_logprobs = None
@@ -57,6 +53,15 @@ class TPUSupportedSamplingMetadata:
     allowed_token_ids_mask = None
     bad_words_token_ids = None

+    # Generator not supported by xla
+    _generators: dict[int,
+                      torch.Generator] = field(default_factory=lambda: dict())
+
+    @property
+    def generators(self) -> dict[int, torch.Generator]:
+        # Generator not supported by torch/xla. This field must be immutable.
+        return self._generators
+
     @classmethod
     def from_input_batch(
         cls,
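The rename to `_generators` plus a getter-only property is the standard Python read-only attribute pattern: downstream code that reads `metadata.generators` keeps working, while accidental assignment now fails. A minimal illustration (class and names are mine, not vLLM's):

    from dataclasses import dataclass, field

    @dataclass
    class Meta:
        _generators: dict[int, object] = field(default_factory=dict)

        @property
        def generators(self) -> dict[int, object]:
            # Read-only view; there is deliberately no setter.
            return self._generators

    m = Meta()
    print(m.generators)   # {}  (reads still work)
    # m.generators = {}   # AttributeError: property has no setter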
@@ -109,5 +114,4 @@ class TPUSupportedSamplingMetadata:
             top_p=None,  # input_batch.top_p[:padded_num_reqs],
             top_k=None,  # input_batch.top_k[:padded_num_reqs],
             min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
-                xla_device),
-            generators=input_batch.generators)
+                xla_device))
@@ -23,7 +23,6 @@ from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
-from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
@@ -267,11 +266,6 @@ class TPUModelRunner:
         for new_req_data in scheduler_output.scheduled_new_reqs:
             req_id = new_req_data.req_id
             sampling_params = new_req_data.sampling_params
-            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-                generator = torch.Generator(device=self.device)
-                generator.manual_seed(sampling_params.seed)
-            else:
-                generator = None

             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
@@ -280,7 +274,7 @@ class TPUModelRunner:
                 mm_inputs=new_req_data.mm_inputs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
-                generator=generator,
+                generator=None,
                 block_ids=new_req_data.block_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 output_token_ids=[],
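For reference, a hedged sketch of the per-request seeding pattern removed above (it is how seeded sampling typically works on non-XLA backends): one torch.Generator per seeded request makes that request's sampling reproducible.

    import torch

    def make_generator(seed: int, device: str = "cpu") -> torch.Generator:
        # Per-request RNG state; torch/XLA has no equivalent, hence the removal.
        gen = torch.Generator(device=device)
        gen.manual_seed(seed)
        return gen

    gen = make_generator(42)
    # Same seed, same draw: this is what a per-request seed buys on GPU/CPU.
    print(torch.multinomial(torch.ones(8), num_samples=1, generator=gen))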