[Hardware] add platform-specific request validation api (#16291)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
parent fee5b8d37f
commit cb391d85dc
@@ -180,7 +180,3 @@ class CpuPlatform(Platform):
         Get device specific communicator class for distributed communication.
         """
         return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa
-
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        return True
@@ -308,10 +308,6 @@ class CudaPlatformBase(Platform):
     def supports_v1(cls, model_config: ModelConfig) -> bool:
         return True
 
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        return True
-
     @classmethod
     def use_custom_allreduce(cls) -> bool:
         return True
@@ -92,7 +92,3 @@ class HpuPlatform(Platform):
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator"  # noqa
-
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        return True
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-
 import enum
 import platform
 import random
@@ -9,14 +8,21 @@ from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
 import numpy as np
 import torch
 
+from vllm.inputs import PromptType
 from vllm.logger import init_logger
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig, VllmConfig
+    from vllm.lora.request import LoRARequest
+    from vllm.pooling_params import PoolingParams
+    from vllm.sampling_params import SamplingParams
     from vllm.utils import FlexibleArgumentParser
 else:
     ModelConfig = None
     VllmConfig = None
+    LoRARequest = None
+    PoolingParams = None
+    SamplingParams = None
     FlexibleArgumentParser = None
 
 logger = init_logger(__name__)
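The added names follow this file's existing `TYPE_CHECKING` idiom: the imports are evaluated only by static type checkers, while the `else` branch binds the same names to `None` so annotations still resolve at runtime without eagerly importing those modules. A condensed, standalone illustration of the idiom (a sketch, not vLLM code):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by type checkers (mypy, pyright), never at runtime.
        from vllm.sampling_params import SamplingParams
    else:
        # The name still resolves at runtime, but no import is executed, so
        # vllm.platforms does not pull in vllm.sampling_params eagerly.
        SamplingParams = None

    def takes_params(params: "SamplingParams") -> None:
        """The string annotation means the None binding is never inspected."""

This keeps the platform interface importable early in startup, before the heavier request/param modules are loaded.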
@@ -379,13 +385,6 @@ class Platform:
         """
         return False
 
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        """
-        Returns whether the current platform can support structured output.
-        """
-        return False
-
     @classmethod
     def use_custom_allreduce(cls) -> bool:
         """
@@ -393,6 +392,14 @@ class Platform:
         """
         return False
 
+    @classmethod
+    def validate_request(
+        cls,
+        prompt: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+    ) -> None:
+        """Raises if this request is unsupported on this platform"""
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
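The new hook is a deliberate no-op on the base class, so only platforms with real restrictions need to override it. A minimal sketch of what an override could look like; `FooPlatform` and its logprobs restriction are hypothetical, invented here for illustration:

    from typing import Union

    from vllm.inputs import PromptType
    from vllm.platforms.interface import Platform, PlatformEnum
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams

    class FooPlatform(Platform):
        # Hypothetical accelerator, not part of this commit.
        _enum = PlatformEnum.UNSPECIFIED
        device_name: str = "foo"

        @classmethod
        def validate_request(
            cls,
            prompt: PromptType,
            params: Union[SamplingParams, PoolingParams],
        ) -> None:
            """Raises if this request is unsupported on this platform"""
            # Reject a sampling feature this imaginary hardware cannot serve.
            if isinstance(params,
                          SamplingParams) and params.logprobs is not None:
                raise ValueError(
                    f"logprobs is not supported on {cls.device_name}.")

Raising ValueError keeps the failure mode consistent with the TPU override below and with the error the processor previously raised itself.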
@@ -67,7 +67,3 @@ class NeuronPlatform(Platform):
     @classmethod
     def use_all_gather(cls) -> bool:
         return True
-
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        return True
@@ -303,10 +303,6 @@ class RocmPlatform(Platform):
         # V1 support on AMD gpus is experimental
         return True
 
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        return True
-
     @classmethod
     def use_custom_allreduce(cls) -> bool:
         # We only enable custom allreduce for MI300 series
@@ -1,19 +1,26 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 
 import vllm.envs as envs
+from vllm.inputs import PromptType
 from vllm.logger import init_logger
 
 from .interface import Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig, VllmConfig
+    from vllm.lora.request import LoRARequest
+    from vllm.pooling_params import PoolingParams
+    from vllm.sampling_params import SamplingParams
 else:
     ModelConfig = None
     VllmConfig = None
+    LoRARequest = None
+    PoolingParams = None
+    SamplingParams = None
 
 logger = init_logger(__name__)
 
@@ -135,6 +142,13 @@ class TpuPlatform(Platform):
         return True
 
     @classmethod
-    def supports_structured_output(cls) -> bool:
-        # Structured output is not supported on TPU.
-        return False
+    def validate_request(
+        cls,
+        prompt: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+    ) -> None:
+        """Raises if this request is unsupported on this platform"""
+        if isinstance(params,
+                      SamplingParams) and params.guided_decoding is not None:
+            raise ValueError("Structured output is not supported on "
+                             f"{cls.device_name}.")
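With this override, a structured-output request fails at validation time instead of deep inside the engine. A rough sketch of the observable behavior, assuming a TPU environment (`GuidedDecodingParams` is the existing guided-decoding container in `vllm.sampling_params`):

    from vllm.platforms.tpu import TpuPlatform
    from vllm.sampling_params import GuidedDecodingParams, SamplingParams

    params = SamplingParams(
        guided_decoding=GuidedDecodingParams(json={"type": "object"}))

    try:
        TpuPlatform.validate_request(prompt="tell me a joke", params=params)
    except ValueError as e:
        print(e)  # Structured output is not supported on <device_name>.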
@@ -140,7 +140,3 @@ class XPUPlatform(Platform):
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa
-
-    @classmethod
-    def supports_structured_output(cls) -> bool:
-        return True
@@ -141,11 +141,6 @@ class Processor:
         else:
             params.guided_decoding.backend = engine_level_backend
 
-        from vllm.platforms import current_platform
-        if not current_platform.supports_structured_output():
-            raise ValueError("Structured output is not supported on "
-                             f"{current_platform.device_name}.")
-
         # Request content validation
         if engine_level_backend.startswith("xgrammar"):
             # xgrammar with no fallback
@@ -187,6 +182,11 @@ class Processor:
         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.
 
+        from vllm.platforms import current_platform
+        current_platform.validate_request(
+            prompt=prompt,
+            params=params,
+        )
         self._validate_lora(lora_request)
         self._validate_params(params)
         if priority != 0:
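Net effect: the guided-decoding branch no longer asks a global boolean capability question; every request is handed to the platform hook alongside the existing LoRA and params checks, and the platform raises on anything it cannot serve. A minimal sketch of the dispatch, assuming a platform with no restrictions (which inherits the base no-op hook):

    from vllm.platforms import current_platform
    from vllm.sampling_params import SamplingParams

    # Unrestricted platforms inherit Platform.validate_request, a no-op, so
    # an ordinary request passes straight through; restrictive platforms
    # (e.g. the TPU override above) raise ValueError before it is queued.
    current_platform.validate_request(
        prompt="Hello, world",
        params=SamplingParams(max_tokens=16),
    )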