[Hardware] add platform-specific request validation api (#16291)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
Joe Runde 2025-04-09 21:50:01 +02:00 committed by GitHub
parent fee5b8d37f
commit cb391d85dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 38 additions and 41 deletions

View File

@ -180,7 +180,3 @@ class CpuPlatform(Platform):
Get device specific communicator class for distributed communication.
"""
return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator" # noqa
@classmethod
def supports_structured_output(cls) -> bool:
return True

View File

@ -308,10 +308,6 @@ class CudaPlatformBase(Platform):
def supports_v1(cls, model_config: ModelConfig) -> bool:
return True
@classmethod
def supports_structured_output(cls) -> bool:
return True
@classmethod
def use_custom_allreduce(cls) -> bool:
return True

View File

@ -92,7 +92,3 @@ class HpuPlatform(Platform):
@classmethod
def get_device_communicator_cls(cls) -> str:
return "vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator" # noqa
@classmethod
def supports_structured_output(cls) -> bool:
return True

View File

@ -1,5 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
import enum
import platform
import random
@ -9,14 +8,21 @@ from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
import numpy as np
import torch
from vllm.inputs import PromptType
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import FlexibleArgumentParser
else:
ModelConfig = None
VllmConfig = None
LoRARequest = None
PoolingParams = None
SamplingParams = None
FlexibleArgumentParser = None
logger = init_logger(__name__)
@ -379,13 +385,6 @@ class Platform:
"""
return False
@classmethod
def supports_structured_output(cls) -> bool:
"""
Returns whether the current platform can support structured output.
"""
return False
@classmethod
def use_custom_allreduce(cls) -> bool:
"""
@ -393,6 +392,14 @@ class Platform:
"""
return False
@classmethod
def validate_request(
cls,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
) -> None:
"""Raise if this request is unsupported on this platform.

The base implementation consists of this docstring only, so it accepts
every request and implicitly returns None. Platform subclasses may
override it to reject requests they cannot serve.

Args:
    prompt: The prompt of the incoming request.
    params: The sampling or pooling parameters of the incoming request.
"""
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED

View File

@ -67,7 +67,3 @@ class NeuronPlatform(Platform):
@classmethod
def use_all_gather(cls) -> bool:
return True
@classmethod
def supports_structured_output(cls) -> bool:
return True

View File

@ -303,10 +303,6 @@ class RocmPlatform(Platform):
# V1 support on AMD gpus is experimental
return True
@classmethod
def supports_structured_output(cls) -> bool:
return True
@classmethod
def use_custom_allreduce(cls) -> bool:
# We only enable custom allreduce for MI300 series

View File

@ -1,19 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union
import torch
import vllm.envs as envs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
else:
ModelConfig = None
VllmConfig = None
LoRARequest = None
PoolingParams = None
SamplingParams = None
logger = init_logger(__name__)
@ -135,6 +142,13 @@ class TpuPlatform(Platform):
return True
@classmethod
def supports_structured_output(cls) -> bool:
# Structured output is not supported on TPU.
return False
def validate_request(
    cls,
    prompt: PromptType,
    params: Union[SamplingParams, PoolingParams],
) -> None:
    """Raise if this request is unsupported on this platform.

    Requests that ask for structured output (guided decoding) are rejected.

    Args:
        prompt: The prompt of the incoming request (not inspected here).
        params: The sampling or pooling parameters of the incoming request.

    Raises:
        ValueError: If ``params`` is a ``SamplingParams`` whose
            ``guided_decoding`` is set.
    """
    # The module-level ``SamplingParams`` is only imported under
    # ``TYPE_CHECKING`` and is bound to ``None`` at runtime, so passing it
    # to ``isinstance`` would raise ``TypeError`` on every request. Import
    # the real class here; this local name shadows the placeholder.
    from vllm.sampling_params import SamplingParams

    if (isinstance(params, SamplingParams)
            and params.guided_decoding is not None):
        raise ValueError("Structured output is not supported on "
                         f"{cls.device_name}.")

View File

@ -140,7 +140,3 @@ class XPUPlatform(Platform):
@classmethod
def get_device_communicator_cls(cls) -> str:
return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa
@classmethod
def supports_structured_output(cls) -> bool:
return True

View File

@ -141,11 +141,6 @@ class Processor:
else:
params.guided_decoding.backend = engine_level_backend
from vllm.platforms import current_platform
if not current_platform.supports_structured_output():
raise ValueError("Structured output is not supported on "
f"{current_platform.device_name}.")
# Request content validation
if engine_level_backend.startswith("xgrammar"):
# xgrammar with no fallback
@ -187,6 +182,11 @@ class Processor:
# TODO(woosuk): Support pooling models.
# TODO(woosuk): Support encoder-decoder models.
from vllm.platforms import current_platform
current_platform.validate_request(
prompt=prompt,
params=params,
)
self._validate_lora(lora_request)
self._validate_params(params)
if priority != 0: