Add cuda_device_count_stateless
(#5473)
This commit is contained in:
parent
e38042d4af
commit
50eed24d25
@ -48,6 +48,7 @@ steps:
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
|
||||
|
||||
- label: Distributed Tests (Multiple Groups)
|
||||
#mirror_hardwares: [amd]
|
||||
|
@ -1,8 +1,6 @@
|
||||
import contextlib
|
||||
import gc
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypeVar
|
||||
|
||||
import pytest
|
||||
@ -22,7 +20,7 @@ from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalData
|
||||
from vllm.multimodal.image import ImageFeatureData, ImagePixelData
|
||||
from vllm.sequence import SampleLogprobs
|
||||
from vllm.utils import is_cpu
|
||||
from vllm.utils import cuda_device_count_stateless, is_cpu
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -539,15 +537,4 @@ def num_gpus_available():
|
||||
"""Get number of GPUs without initializing the CUDA context
|
||||
in current process."""
|
||||
|
||||
try:
|
||||
out = subprocess.run([
|
||||
sys.executable, "-c",
|
||||
"import torch; print(torch.cuda.device_count())"
|
||||
],
|
||||
capture_output=True,
|
||||
check=True,
|
||||
text=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.warning("Failed to get number of GPUs.", exc_info=e)
|
||||
return 0
|
||||
return int(out.stdout.strip())
|
||||
return cuda_device_count_stateless()
|
||||
|
31
tests/distributed/test_utils.py
Normal file
31
tests/distributed/test_utils.py
Normal file
@ -0,0 +1,31 @@
|
||||
import os
|
||||
|
||||
import ray
|
||||
|
||||
from vllm.utils import cuda_device_count_stateless
|
||||
|
||||
|
||||
@ray.remote
|
||||
class _CUDADeviceCountStatelessTestActor():
|
||||
|
||||
def get_count(self):
|
||||
return cuda_device_count_stateless()
|
||||
|
||||
def set_cuda_visible_devices(self, cuda_visible_devices: str):
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
|
||||
|
||||
def get_cuda_visible_devices(self):
|
||||
return os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
|
||||
|
||||
def test_cuda_device_count_stateless():
|
||||
"""Test that cuda_device_count_stateless changes return value if
|
||||
CUDA_VISIBLE_DEVICES is changed."""
|
||||
|
||||
actor = _CUDADeviceCountStatelessTestActor.options(num_gpus=2).remote()
|
||||
assert ray.get(actor.get_cuda_visible_devices.remote()) == "0,1"
|
||||
assert ray.get(actor.get_count.remote()) == 2
|
||||
ray.get(actor.set_cuda_visible_devices.remote("0"))
|
||||
assert ray.get(actor.get_count.remote()) == 1
|
||||
ray.get(actor.set_cuda_visible_devices.remote(""))
|
||||
assert ray.get(actor.get_count.remote()) == 0
|
@ -11,7 +11,8 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.transformers_utils.config import get_config, get_hf_text_config
|
||||
from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron, is_tpu
|
||||
from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
|
||||
is_hip, is_neuron, is_tpu)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
@ -605,12 +606,11 @@ class ParallelConfig:
|
||||
if self.distributed_executor_backend is None and self.world_size > 1:
|
||||
# We use multiprocessing by default if world_size fits on the
|
||||
# current node and we aren't in a ray placement group.
|
||||
from torch.cuda import device_count
|
||||
|
||||
from vllm.executor import ray_utils
|
||||
backend = "mp"
|
||||
ray_found = ray_utils.ray is not None
|
||||
if device_count() < self.world_size:
|
||||
if cuda_device_count_stateless() < self.world_size:
|
||||
if not ray_found:
|
||||
raise ValueError("Unable to load Ray which is "
|
||||
"required for multi-node inference")
|
||||
|
@ -11,6 +11,7 @@ from vllm.distributed.device_communicators.custom_all_reduce_utils import (
|
||||
gpu_p2p_access_check)
|
||||
from vllm.distributed.parallel_state import is_in_the_same_node
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cuda_device_count_stateless
|
||||
|
||||
try:
|
||||
import pynvml
|
||||
@ -144,7 +145,7 @@ class CustomAllreduce:
|
||||
if cuda_visible_devices:
|
||||
device_ids = list(map(int, cuda_visible_devices.split(",")))
|
||||
else:
|
||||
device_ids = list(range(torch.cuda.device_count()))
|
||||
device_ids = list(range(cuda_device_count_stateless()))
|
||||
|
||||
physical_device_id = device_ids[device.index]
|
||||
tensor = torch.tensor([physical_device_id],
|
||||
|
@ -12,6 +12,7 @@ import torch.multiprocessing as mp
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cuda_device_count_stateless
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -152,7 +153,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
|
||||
|
||||
is_distributed = dist.is_initialized()
|
||||
|
||||
num_dev = torch.cuda.device_count()
|
||||
num_dev = cuda_device_count_stateless()
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
if cuda_visible_devices is None:
|
||||
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
|
||||
|
@ -9,7 +9,8 @@ from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
|
||||
ResultHandler, WorkerMonitor)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
from vllm.utils import (cuda_device_count_stateless,
|
||||
get_distributed_init_method, get_ip, get_open_port,
|
||||
get_vllm_instance_id, make_async)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -33,8 +34,7 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
|
||||
# Disable torch async compiling which won't work with daemonic processes
|
||||
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
|
||||
|
||||
from torch.cuda import device_count
|
||||
assert world_size <= device_count(), (
|
||||
assert world_size <= cuda_device_count_stateless(), (
|
||||
"please set tensor_parallel_size to less than max local gpu count")
|
||||
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
|
@ -693,3 +693,38 @@ def deprecate_kwargs(
|
||||
return inner # type: ignore
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def _cuda_device_count_stateless(
|
||||
cuda_visible_devices: Optional[str] = None) -> int:
|
||||
# Note: cuda_visible_devices is not used, but we keep it as an argument for
|
||||
# LRU Cache purposes.
|
||||
|
||||
# Code below is based on
|
||||
# https://github.com/pytorch/pytorch/blob/
|
||||
# c1cd946818442aca8c7f812b16d187ce1586c3bc/
|
||||
# torch/cuda/__init__.py#L831C1-L831C17
|
||||
import torch.cuda
|
||||
import torch.version
|
||||
|
||||
if not torch.cuda._is_compiled():
|
||||
return 0
|
||||
# bypass _device_count_nvml() if rocm (not supported)
|
||||
nvml_count = -1 if torch.version.hip else torch.cuda._device_count_nvml()
|
||||
r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count
|
||||
return r
|
||||
|
||||
|
||||
def cuda_device_count_stateless() -> int:
|
||||
"""Get number of CUDA devices, caching based on the value of
|
||||
CUDA_VISIBLE_DEVICES at the time of call.
|
||||
|
||||
This should be used instead of torch.cuda.device_count()
|
||||
unless CUDA_VISIBLE_DEVICES has already been set to the desired
|
||||
value."""
|
||||
|
||||
# This can be removed and simply replaced with torch.cuda.get_device_count
|
||||
# after https://github.com/pytorch/pytorch/pull/122815 is released.
|
||||
|
||||
return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
|
||||
|
Loading…
x
Reference in New Issue
Block a user