diff --git a/vllm/config.py b/vllm/config.py
index 08947e39..f86c3272 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -417,8 +417,10 @@ class ModelConfig:
 
         from vllm.platforms import current_platform
 
-        if self.enable_sleep_mode and not current_platform.is_cuda():
-            raise ValueError("Sleep mode is only supported on CUDA devices.")
+        if (self.enable_sleep_mode
+                and not current_platform.is_sleep_mode_available()):
+            raise ValueError(
+                "Sleep mode is not supported on the current platform.")
 
         hf_config = get_config(self.hf_config_path or self.model,
                                trust_remote_code, revision, code_revision,
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 31a7ffbd..2695da57 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -148,6 +148,9 @@ class Platform:
         """Stateless version of :func:`torch.cuda.is_available`."""
         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
 
+    def is_sleep_mode_available(self) -> bool:
+        return self._enum == PlatformEnum.CUDA
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
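
The new Platform.is_sleep_mode_available() hook turns sleep-mode support into a
per-platform capability query, so ModelConfig no longer hard-codes a CUDA check.
As a minimal sketch of how a backend could opt in under this interface, an
out-of-tree platform would override the hook; the HypotheticalPlatform class
below and its claimed sleep/wake support are illustrative assumptions, not part
of this diff:

    # Hypothetical out-of-tree platform opting in to sleep mode. Only
    # Platform, PlatformEnum, and is_sleep_mode_available() come from the
    # diff; everything else here is an assumption.
    from vllm.platforms.interface import Platform, PlatformEnum


    class HypotheticalPlatform(Platform):
        # OOT is the enum slot the interface reserves for out-of-tree
        # platforms (assumed present in this version of the file).
        _enum = PlatformEnum.OOT

        def is_sleep_mode_available(self) -> bool:
            # The base class returns True only for PlatformEnum.CUDA; a
            # backend that implements its own sleep/wake support can
            # override the hook to pass the ModelConfig validation above.
            return True


    assert HypotheticalPlatform().is_sleep_mode_available()

Routing the check through the platform object keeps config validation
device-agnostic: enabling sleep mode on a new backend becomes a one-method
override instead of another edit to ModelConfig.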