[V1][Structured Output] Add supports_structured_output() method to Platform (#16148)

Signed-off-by: shen-shanshan <467638484@qq.com>
Authored by Shanshan Shen on 2025-04-07 19:06:24 +08:00; committed by GitHub
parent 7c80368710
commit e9ba99f296
9 changed files with 41 additions and 3 deletions
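
For orientation before the per-file hunks: the commit threads a single capability flag through the Platform hierarchy. The outcome can be summarized with an illustrative snippet (a reading aid, not code from the commit; the values come from the hunks below):

    # Support matrix encoded by this diff. The Platform base class
    # defaults to False, and every backend except TPU overrides to True.
    SUPPORTS_STRUCTURED_OUTPUT = {
        "cpu": True,
        "cuda": True,
        "hpu": True,
        "neuron": True,
        "rocm": True,
        "tpu": False,  # explicitly opts out
        "xpu": True,
    }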

vllm/platforms/cpu.py

@@ -180,3 +180,7 @@ class CpuPlatform(Platform):
         Get device specific communicator class for distributed communication.
         """
         return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa
+
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        return True

vllm/platforms/cuda.py

@@ -308,6 +308,10 @@ class CudaPlatformBase(Platform):
     def supports_v1(cls, model_config: ModelConfig) -> bool:
         return True
 
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        return True
+
     @classmethod
     def use_custom_allreduce(cls) -> bool:
         return True

vllm/platforms/hpu.py

@@ -92,3 +92,7 @@ class HpuPlatform(Platform):
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator"  # noqa
+
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        return True

vllm/platforms/interface.py

@@ -379,6 +379,13 @@ class Platform:
         """
         return False
 
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        """
+        Returns whether the current platform can support structured output.
+        """
+        return False
+
     @classmethod
     def use_custom_allreduce(cls) -> bool:
         """

vllm/platforms/neuron.py

@@ -67,3 +67,7 @@ class NeuronPlatform(Platform):
     @classmethod
     def use_all_gather(cls) -> bool:
         return True
+
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        return True

vllm/platforms/rocm.py

@@ -303,6 +303,10 @@ class RocmPlatform(Platform):
         # V1 support on AMD gpus is experimental
         return True
 
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        return True
+
     @classmethod
     def use_custom_allreduce(cls) -> bool:
         # We only enable custom allreduce for MI300 series

vllm/platforms/tpu.py

@@ -133,3 +133,8 @@ class TpuPlatform(Platform):
     def supports_v1(cls, model_config: ModelConfig) -> bool:
         # V1 support on TPU is experimental
         return True
+
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        # Structured output is not supported on TPU.
+        return False

vllm/platforms/xpu.py

@@ -140,3 +140,7 @@ class XPUPlatform(Platform):
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa
+
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        return True

vllm/v1/engine/processor.py

@@ -136,9 +136,11 @@ class Processor:
                     f" != {engine_level_backend}")
             else:
                 params.guided_decoding.backend = engine_level_backend
-        import vllm.platforms
-        if vllm.platforms.current_platform.is_tpu():
-            raise ValueError("Structured output is not supported on TPU.")
+        from vllm.platforms import current_platform
+        if not current_platform.supports_structured_output():
+            raise ValueError(
+                "Structured output is not supported on "
+                f"{current_platform.device_name}.")
 
         # Request content validation
         if engine_level_backend.startswith("xgrammar"):
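
Net effect of the processor change: the hard-coded is_tpu() special case becomes a generic capability query, so excluding (or later enabling) a platform no longer requires touching the processor. A sketch of the resulting guard in isolation (check_structured_output is a hypothetical helper; the error text mirrors the hunk above):

    # Rejecting guided decoding on unsupported platforms, as the V1
    # Processor now does via the new Platform hook.
    from vllm.platforms import current_platform

    def check_structured_output() -> None:
        if not current_platform.supports_structured_output():
            # On TPU this raises, e.g.:
            #   ValueError: Structured output is not supported on tpu.
            raise ValueError("Structured output is not supported on "
                             f"{current_platform.device_name}.")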