[Refactor]A simple device-related refactor (#11163)

Signed-off-by: noemotiovon <noemotiovon@gmail.com>
Co-authored-by: noemotiovon <noemotiovon@gmail.com>
This commit is contained in:
Chenguang Li 2024-12-13 21:39:00 +08:00 committed by GitHub
parent 969da7d70b
commit d1fa714cb1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 51 additions and 31 deletions

View File

@ -98,3 +98,8 @@ class CpuPlatform(Platform):
"vllm.worker.cpu_worker.CPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
@classmethod
def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on CPU.")
return False

View File

@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Optional
import torch
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
@ -9,6 +11,8 @@ if TYPE_CHECKING:
else:
VllmConfig = None
logger = init_logger(__name__)
class HpuPlatform(Platform):
_enum = PlatformEnum.HPU
@ -43,3 +47,8 @@ class HpuPlatform(Platform):
parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
@classmethod
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on HPU.")
return False

View File

@ -1,6 +1,7 @@
import enum
import platform
import random
from platform import uname
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
import numpy as np
@ -16,6 +17,11 @@ else:
logger = init_logger(__name__)
def in_wsl() -> bool:
# Reference: https://github.com/microsoft/WSL/issues/4071
return "microsoft" in " ".join(uname()).lower()
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
FLASH_ATTN_VLLM_V1 = enum.auto()
@ -221,6 +227,17 @@ class Platform:
return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN
@classmethod
def is_pin_memory_available(cls) -> bool:
"""Checks whether pin memory is available on the current platform."""
if in_wsl():
# Pinning memory in WSL is not supported.
# https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
logger.warning("Using 'pin_memory=False' as WSL is detected. "
"This may slow down the performance.")
return False
return True
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED

View File

@ -1,5 +1,7 @@
from typing import TYPE_CHECKING, Optional
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum
if TYPE_CHECKING:
@ -7,6 +9,8 @@ if TYPE_CHECKING:
else:
VllmConfig = None
logger = init_logger(__name__)
class NeuronPlatform(Platform):
_enum = PlatformEnum.NEURON
@ -28,3 +32,8 @@ class NeuronPlatform(Platform):
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = \
"vllm.worker.neuron_worker.NeuronWorker"
@classmethod
def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on Neuron.")
return False

View File

@ -34,7 +34,7 @@ class OpenVinoPlatform(Platform):
return _Backend.OPENVINO
@classmethod
def get_device_name(self, device_id: int = 0) -> str:
def get_device_name(cls, device_id: int = 0) -> str:
return "openvino"
@classmethod
@ -42,19 +42,19 @@ class OpenVinoPlatform(Platform):
return False
@classmethod
def inference_mode(self):
def inference_mode(cls):
return torch.inference_mode(mode=True)
@classmethod
def is_openvino_cpu(self) -> bool:
def is_openvino_cpu(cls) -> bool:
return "CPU" in envs.VLLM_OPENVINO_DEVICE
@classmethod
def is_openvino_gpu(self) -> bool:
def is_openvino_gpu(cls) -> bool:
return "GPU" in envs.VLLM_OPENVINO_DEVICE
@classmethod
def is_pin_memory_available(self) -> bool:
def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on OpenViNO.")
return False

View File

@ -78,3 +78,8 @@ class XPUPlatform(Platform):
parallel_config.distributed_executor_backend = "ray"
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
@classmethod
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on XPU.")
return False

View File

@ -24,7 +24,6 @@ from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections import UserDict, defaultdict
from collections.abc import Iterable, Mapping
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
Dict, Generic, Hashable, List, Literal, Optional,
OrderedDict, Set, Tuple, Type, TypeVar, Union, overload)
@ -344,12 +343,6 @@ def random_uuid() -> str:
return str(uuid.uuid4().hex)
@lru_cache(maxsize=None)
def in_wsl() -> bool:
# Reference: https://github.com/microsoft/WSL/issues/4071
return "microsoft" in " ".join(uname()).lower()
def make_async(
func: Callable[P, T],
executor: Optional[concurrent.futures.Executor] = None
@ -729,25 +722,7 @@ def print_warning_once(msg: str) -> None:
@lru_cache(maxsize=None)
def is_pin_memory_available() -> bool:
if in_wsl():
# Pinning memory in WSL is not supported.
# https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
print_warning_once("Using 'pin_memory=False' as WSL is detected. "
"This may slow down the performance.")
return False
elif current_platform.is_xpu():
print_warning_once("Pin memory is not supported on XPU.")
return False
elif current_platform.is_neuron():
print_warning_once("Pin memory is not supported on Neuron.")
return False
elif current_platform.is_hpu():
print_warning_once("Pin memory is not supported on HPU.")
return False
elif current_platform.is_cpu() or current_platform.is_openvino():
return False
return True
return current_platform.is_pin_memory_available()
class DeviceMemoryProfiler: