diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 3c8aecc0..70553354 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -180,7 +180,3 @@ class CpuPlatform(Platform):
         Get device specific communicator class for distributed communication.
         """
         return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa
-
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        return True
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 053cf74e..0576022b 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -308,10 +308,6 @@ class CudaPlatformBase(Platform):
     def supports_v1(cls, model_config: ModelConfig) -> bool:
         return True
 
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        return True
-
     @classmethod
     def use_custom_allreduce(cls) -> bool:
         return True
diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py
index f011f140..4c842b52 100644
--- a/vllm/platforms/hpu.py
+++ b/vllm/platforms/hpu.py
@@ -92,7 +92,3 @@ class HpuPlatform(Platform):
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator"  # noqa
-
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        return True
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index f788d90b..31a7ffbd 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-
 import enum
 import platform
 import random
@@ -9,14 +8,21 @@ from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
 import numpy as np
 import torch
 
+from vllm.inputs import PromptType
 from vllm.logger import init_logger
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig, VllmConfig
+    from vllm.lora.request import LoRARequest
+    from vllm.pooling_params import PoolingParams
+    from vllm.sampling_params import SamplingParams
     from vllm.utils import FlexibleArgumentParser
 else:
     ModelConfig = None
     VllmConfig = None
+    LoRARequest = None
+    PoolingParams = None
+    SamplingParams = None
     FlexibleArgumentParser = None
 
 logger = init_logger(__name__)
@@ -379,13 +385,6 @@ class Platform:
         """
         return False
 
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        """
-        Returns whether the current platform can support structured output.
-        """
-        return False
-
     @classmethod
     def use_custom_allreduce(cls) -> bool:
         """
@@ -393,6 +392,14 @@ class Platform:
         """
         return False
 
+    @classmethod
+    def validate_request(
+        cls,
+        prompt: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+    ) -> None:
+        """Raises if this request is unsupported on this platform"""
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py
index 93657881..c1f426e5 100644
--- a/vllm/platforms/neuron.py
+++ b/vllm/platforms/neuron.py
@@ -67,7 +67,3 @@ class NeuronPlatform(Platform):
     @classmethod
     def use_all_gather(cls) -> bool:
         return True
-
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        return True
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index a2fbf416..d18b7c26 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -303,10 +303,6 @@ class RocmPlatform(Platform):
         # V1 support on AMD gpus is experimental
         return True
 
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        return True
-
     @classmethod
     def use_custom_allreduce(cls) -> bool:
         # We only enable custom allreduce for MI300 series
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index eeadb4a7..d5848424 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -1,19 +1,26 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 
 import vllm.envs as envs
+from vllm.inputs import PromptType
 from vllm.logger import init_logger
 
 from .interface import Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig, VllmConfig
+    from vllm.lora.request import LoRARequest
+    from vllm.pooling_params import PoolingParams
+    from vllm.sampling_params import SamplingParams
 else:
     ModelConfig = None
     VllmConfig = None
+    LoRARequest = None
+    PoolingParams = None
+    SamplingParams = None
 
 logger = init_logger(__name__)
 
@@ -135,6 +142,13 @@
         return True
 
     @classmethod
-    def supports_structured_output(cls) -> bool:
-        # Structured output is not supported on TPU.
-        return False
+    def validate_request(
+        cls,
+        prompt: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+    ) -> None:
+        """Raises if this request is unsupported on this platform"""
+        if isinstance(params,
+                      SamplingParams) and params.guided_decoding is not None:
+            raise ValueError("Structured output is not supported on "
+                             f"{cls.device_name}.")
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index c4bd6393..225e756c 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -140,7 +140,3 @@ class XPUPlatform(Platform):
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa
-
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        return True
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 5f9c8ea4..2525b10a 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -141,11 +141,6 @@ class Processor:
         else:
             params.guided_decoding.backend = engine_level_backend
 
-        from vllm.platforms import current_platform
-        if not current_platform.supports_structured_output():
-            raise ValueError("Structured output is not supported on "
-                             f"{current_platform.device_name}.")
-
         # Request content validation
         if engine_level_backend.startswith("xgrammar"):
             # xgrammar with no fallback
@@ -187,6 +182,11 @@
         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.
 
+        from vllm.platforms import current_platform
+        current_platform.validate_request(
+            prompt=prompt,
+            params=params,
+        )
        self._validate_lora(lora_request)
        self._validate_params(params)
        if priority != 0:
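
Note for out-of-tree platforms: the boolean `supports_structured_output()` capability flag is gone; a platform now rejects unsupported requests per-request by overriding `validate_request`, which `Processor.process_inputs` calls before LoRA and parameter validation. Below is a minimal sketch of a plugin platform using the new hook, mirroring the TPU override above. `MyPlatform`, its `device_name`, and the use of `PlatformEnum.OOT` are illustrative assumptions, not part of this diff.

    # Hypothetical out-of-tree platform -- a sketch, not part of this diff.
    from typing import Union

    from vllm.inputs import PromptType
    from vllm.platforms.interface import Platform, PlatformEnum
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams


    class MyPlatform(Platform):
        _enum = PlatformEnum.OOT  # assumed enum value for plugin platforms
        device_name = "my_accelerator"

        @classmethod
        def validate_request(
            cls,
            prompt: PromptType,
            params: Union[SamplingParams, PoolingParams],
        ) -> None:
            # Mirror the TPU override: fail fast at request time rather
            # than advertising a platform-wide capability flag.
            if (isinstance(params, SamplingParams)
                    and params.guided_decoding is not None):
                raise ValueError("Structured output is not supported on "
                                 f"{cls.device_name}.")

Because the check now runs inside `Processor.process_inputs`, a platform can veto any request shape it cannot serve (not just structured output) without further engine changes.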