[Refactor]A simple device-related refactor (#11163)
Signed-off-by: noemotiovon <noemotiovon@gmail.com> Co-authored-by: noemotiovon <noemotiovon@gmail.com>
This commit is contained in:
parent
969da7d70b
commit
d1fa714cb1
@ -98,3 +98,8 @@ class CpuPlatform(Platform):
|
|||||||
"vllm.worker.cpu_worker.CPUWorker"
|
"vllm.worker.cpu_worker.CPUWorker"
|
||||||
else:
|
else:
|
||||||
parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
|
parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_pin_memory_available(cls) -> bool:
|
||||||
|
logger.warning("Pin memory is not supported on CPU.")
|
||||||
|
return False
|
||||||
|
@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
from .interface import Platform, PlatformEnum, _Backend
|
from .interface import Platform, PlatformEnum, _Backend
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -9,6 +11,8 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
VllmConfig = None
|
VllmConfig = None
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class HpuPlatform(Platform):
|
class HpuPlatform(Platform):
|
||||||
_enum = PlatformEnum.HPU
|
_enum = PlatformEnum.HPU
|
||||||
@ -43,3 +47,8 @@ class HpuPlatform(Platform):
|
|||||||
parallel_config = vllm_config.parallel_config
|
parallel_config = vllm_config.parallel_config
|
||||||
if parallel_config.worker_cls == "auto":
|
if parallel_config.worker_cls == "auto":
|
||||||
parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
|
parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_pin_memory_available(cls):
|
||||||
|
logger.warning("Pin memory is not supported on HPU.")
|
||||||
|
return False
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import enum
|
import enum
|
||||||
import platform
|
import platform
|
||||||
import random
|
import random
|
||||||
|
from platform import uname
|
||||||
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -16,6 +17,11 @@ else:
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def in_wsl() -> bool:
|
||||||
|
# Reference: https://github.com/microsoft/WSL/issues/4071
|
||||||
|
return "microsoft" in " ".join(uname()).lower()
|
||||||
|
|
||||||
|
|
||||||
class _Backend(enum.Enum):
|
class _Backend(enum.Enum):
|
||||||
FLASH_ATTN = enum.auto()
|
FLASH_ATTN = enum.auto()
|
||||||
FLASH_ATTN_VLLM_V1 = enum.auto()
|
FLASH_ATTN_VLLM_V1 = enum.auto()
|
||||||
@ -221,6 +227,17 @@ class Platform:
|
|||||||
|
|
||||||
return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN
|
return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_pin_memory_available(cls) -> bool:
|
||||||
|
"""Checks whether pin memory is available on the current platform."""
|
||||||
|
if in_wsl():
|
||||||
|
# Pinning memory in WSL is not supported.
|
||||||
|
# https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
|
||||||
|
logger.warning("Using 'pin_memory=False' as WSL is detected. "
|
||||||
|
"This may slow down the performance.")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class UnspecifiedPlatform(Platform):
|
class UnspecifiedPlatform(Platform):
|
||||||
_enum = PlatformEnum.UNSPECIFIED
|
_enum = PlatformEnum.UNSPECIFIED
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
from .interface import Platform, PlatformEnum
|
from .interface import Platform, PlatformEnum
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -7,6 +9,8 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
VllmConfig = None
|
VllmConfig = None
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class NeuronPlatform(Platform):
|
class NeuronPlatform(Platform):
|
||||||
_enum = PlatformEnum.NEURON
|
_enum = PlatformEnum.NEURON
|
||||||
@ -28,3 +32,8 @@ class NeuronPlatform(Platform):
|
|||||||
if parallel_config.worker_cls == "auto":
|
if parallel_config.worker_cls == "auto":
|
||||||
parallel_config.worker_cls = \
|
parallel_config.worker_cls = \
|
||||||
"vllm.worker.neuron_worker.NeuronWorker"
|
"vllm.worker.neuron_worker.NeuronWorker"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_pin_memory_available(cls) -> bool:
|
||||||
|
logger.warning("Pin memory is not supported on Neuron.")
|
||||||
|
return False
|
||||||
|
@ -34,7 +34,7 @@ class OpenVinoPlatform(Platform):
|
|||||||
return _Backend.OPENVINO
|
return _Backend.OPENVINO
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_device_name(self, device_id: int = 0) -> str:
|
def get_device_name(cls, device_id: int = 0) -> str:
|
||||||
return "openvino"
|
return "openvino"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -42,19 +42,19 @@ class OpenVinoPlatform(Platform):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def inference_mode(self):
|
def inference_mode(cls):
|
||||||
return torch.inference_mode(mode=True)
|
return torch.inference_mode(mode=True)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_openvino_cpu(self) -> bool:
|
def is_openvino_cpu(cls) -> bool:
|
||||||
return "CPU" in envs.VLLM_OPENVINO_DEVICE
|
return "CPU" in envs.VLLM_OPENVINO_DEVICE
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_openvino_gpu(self) -> bool:
|
def is_openvino_gpu(cls) -> bool:
|
||||||
return "GPU" in envs.VLLM_OPENVINO_DEVICE
|
return "GPU" in envs.VLLM_OPENVINO_DEVICE
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_pin_memory_available(self) -> bool:
|
def is_pin_memory_available(cls) -> bool:
|
||||||
logger.warning("Pin memory is not supported on OpenViNO.")
|
logger.warning("Pin memory is not supported on OpenViNO.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -78,3 +78,8 @@ class XPUPlatform(Platform):
|
|||||||
parallel_config.distributed_executor_backend = "ray"
|
parallel_config.distributed_executor_backend = "ray"
|
||||||
if parallel_config.worker_cls == "auto":
|
if parallel_config.worker_cls == "auto":
|
||||||
parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
|
parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_pin_memory_available(cls):
|
||||||
|
logger.warning("Pin memory is not supported on XPU.")
|
||||||
|
return False
|
||||||
|
@ -24,7 +24,6 @@ from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
|
|||||||
from collections import UserDict, defaultdict
|
from collections import UserDict, defaultdict
|
||||||
from collections.abc import Iterable, Mapping
|
from collections.abc import Iterable, Mapping
|
||||||
from functools import lru_cache, partial, wraps
|
from functools import lru_cache, partial, wraps
|
||||||
from platform import uname
|
|
||||||
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
|
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
|
||||||
Dict, Generic, Hashable, List, Literal, Optional,
|
Dict, Generic, Hashable, List, Literal, Optional,
|
||||||
OrderedDict, Set, Tuple, Type, TypeVar, Union, overload)
|
OrderedDict, Set, Tuple, Type, TypeVar, Union, overload)
|
||||||
@ -344,12 +343,6 @@ def random_uuid() -> str:
|
|||||||
return str(uuid.uuid4().hex)
|
return str(uuid.uuid4().hex)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
|
||||||
def in_wsl() -> bool:
|
|
||||||
# Reference: https://github.com/microsoft/WSL/issues/4071
|
|
||||||
return "microsoft" in " ".join(uname()).lower()
|
|
||||||
|
|
||||||
|
|
||||||
def make_async(
|
def make_async(
|
||||||
func: Callable[P, T],
|
func: Callable[P, T],
|
||||||
executor: Optional[concurrent.futures.Executor] = None
|
executor: Optional[concurrent.futures.Executor] = None
|
||||||
@ -729,25 +722,7 @@ def print_warning_once(msg: str) -> None:
|
|||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def is_pin_memory_available() -> bool:
|
def is_pin_memory_available() -> bool:
|
||||||
|
return current_platform.is_pin_memory_available()
|
||||||
if in_wsl():
|
|
||||||
# Pinning memory in WSL is not supported.
|
|
||||||
# https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
|
|
||||||
print_warning_once("Using 'pin_memory=False' as WSL is detected. "
|
|
||||||
"This may slow down the performance.")
|
|
||||||
return False
|
|
||||||
elif current_platform.is_xpu():
|
|
||||||
print_warning_once("Pin memory is not supported on XPU.")
|
|
||||||
return False
|
|
||||||
elif current_platform.is_neuron():
|
|
||||||
print_warning_once("Pin memory is not supported on Neuron.")
|
|
||||||
return False
|
|
||||||
elif current_platform.is_hpu():
|
|
||||||
print_warning_once("Pin memory is not supported on HPU.")
|
|
||||||
return False
|
|
||||||
elif current_platform.is_cpu() or current_platform.is_openvino():
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class DeviceMemoryProfiler:
|
class DeviceMemoryProfiler:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user