Add cuda_device_count_stateless (#5473)

Antoni Baum 2024-06-13 16:06:49 -07:00 committed by GitHub
parent e38042d4af
commit 50eed24d25
8 changed files with 79 additions and 23 deletions
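
This change adds cuda_device_count_stateless() to vllm.utils and moves callers off torch.cuda.device_count(), whose result can be cached once CUDA is initialized and therefore may not reflect later changes to CUDA_VISIBLE_DEVICES. A minimal usage sketch, mirroring the new test below; the two-GPU machine and the printed counts are illustrative assumptions, not part of this diff:

import os

from vllm.utils import cuda_device_count_stateless

# Illustrative assumption: the process starts on a machine with two GPUs visible.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
print(cuda_device_count_stateless())  # 2

# Narrowing visibility afterwards is picked up on the next call, because the
# helper re-reads CUDA_VISIBLE_DEVICES and caches one result per observed value.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
print(cuda_device_count_stateless())  # 1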

View File

@@ -48,6 +48,7 @@ steps:
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
   - pytest -v -s spec_decode/e2e/test_integration_dist.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
+  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
 
 - label: Distributed Tests (Multiple Groups)
   #mirror_hardwares: [amd]

View File

@@ -1,8 +1,6 @@
 import contextlib
 import gc
 import os
-import subprocess
-import sys
 from typing import Any, Dict, List, Optional, Tuple, TypeVar
 
 import pytest
@@ -22,7 +20,7 @@ from vllm.logger import init_logger
 from vllm.multimodal import MultiModalData
 from vllm.multimodal.image import ImageFeatureData, ImagePixelData
 from vllm.sequence import SampleLogprobs
-from vllm.utils import is_cpu
+from vllm.utils import cuda_device_count_stateless, is_cpu
 
 logger = init_logger(__name__)
@@ -539,15 +537,4 @@ def num_gpus_available():
     """Get number of GPUs without initializing the CUDA context
     in current process."""
-    try:
-        out = subprocess.run([
-            sys.executable, "-c",
-            "import torch; print(torch.cuda.device_count())"
-        ],
-                             capture_output=True,
-                             check=True,
-                             text=True)
-    except subprocess.CalledProcessError as e:
-        logger.warning("Failed to get number of GPUs.", exc_info=e)
-        return 0
-    return int(out.stdout.strip())
+    return cuda_device_count_stateless()

View File

@@ -0,0 +1,31 @@
+import os
+
+import ray
+
+from vllm.utils import cuda_device_count_stateless
+
+
+@ray.remote
+class _CUDADeviceCountStatelessTestActor():
+
+    def get_count(self):
+        return cuda_device_count_stateless()
+
+    def set_cuda_visible_devices(self, cuda_visible_devices: str):
+        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
+
+    def get_cuda_visible_devices(self):
+        return os.environ["CUDA_VISIBLE_DEVICES"]
+
+
+def test_cuda_device_count_stateless():
+    """Test that cuda_device_count_stateless changes return value if
+    CUDA_VISIBLE_DEVICES is changed."""
+
+    actor = _CUDADeviceCountStatelessTestActor.options(num_gpus=2).remote()
+    assert ray.get(actor.get_cuda_visible_devices.remote()) == "0,1"
+    assert ray.get(actor.get_count.remote()) == 2
+    ray.get(actor.set_cuda_visible_devices.remote("0"))
+    assert ray.get(actor.get_count.remote()) == 1
+    ray.get(actor.set_cuda_visible_devices.remote(""))
+    assert ray.get(actor.get_count.remote()) == 0

View File

@@ -11,7 +11,8 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
 from vllm.transformers_utils.config import get_config, get_hf_text_config
-from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron, is_tpu
+from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
+                        is_hip, is_neuron, is_tpu)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -605,12 +606,11 @@ class ParallelConfig:
         if self.distributed_executor_backend is None and self.world_size > 1:
             # We use multiprocessing by default if world_size fits on the
             # current node and we aren't in a ray placement group.
-            from torch.cuda import device_count
             from vllm.executor import ray_utils
             backend = "mp"
             ray_found = ray_utils.ray is not None
-            if device_count() < self.world_size:
+            if cuda_device_count_stateless() < self.world_size:
                 if not ray_found:
                     raise ValueError("Unable to load Ray which is "
                                      "required for multi-node inference")

View File

@@ -11,6 +11,7 @@ from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
 from vllm.distributed.parallel_state import is_in_the_same_node
 from vllm.logger import init_logger
+from vllm.utils import cuda_device_count_stateless
 
 try:
     import pynvml
@@ -144,7 +145,7 @@ class CustomAllreduce:
         if cuda_visible_devices:
             device_ids = list(map(int, cuda_visible_devices.split(",")))
         else:
-            device_ids = list(range(torch.cuda.device_count()))
+            device_ids = list(range(cuda_device_count_stateless()))
         physical_device_id = device_ids[device.index]
         tensor = torch.tensor([physical_device_id],

View File

@@ -12,6 +12,7 @@ import torch.multiprocessing as mp
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.utils import cuda_device_count_stateless
 
 logger = init_logger(__name__)
@@ -152,7 +153,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
     is_distributed = dist.is_initialized()
 
-    num_dev = torch.cuda.device_count()
+    num_dev = cuda_device_count_stateless()
     cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
     if cuda_visible_devices is None:
         cuda_visible_devices = ",".join(str(i) for i in range(num_dev))

View File

@@ -9,7 +9,8 @@ from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                   ResultHandler, WorkerMonitor)
 from vllm.logger import init_logger
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
-from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
+from vllm.utils import (cuda_device_count_stateless,
+                        get_distributed_init_method, get_ip, get_open_port,
                         get_vllm_instance_id, make_async)
 
 logger = init_logger(__name__)
@@ -33,8 +34,7 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
         # Disable torch async compiling which won't work with daemonic processes
         os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
 
-        from torch.cuda import device_count
-        assert world_size <= device_count(), (
+        assert world_size <= cuda_device_count_stateless(), (
             "please set tensor_parallel_size to less than max local gpu count")
 
         distributed_init_method = get_distributed_init_method(

View File

@@ -693,3 +693,38 @@ def deprecate_kwargs(
         return inner  # type: ignore
 
     return wrapper
+
+
+@lru_cache(maxsize=8)
+def _cuda_device_count_stateless(
+        cuda_visible_devices: Optional[str] = None) -> int:
+    # Note: cuda_visible_devices is not used, but we keep it as an argument for
+    # LRU Cache purposes.
+
+    # Code below is based on
+    # https://github.com/pytorch/pytorch/blob/
+    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
+    # torch/cuda/__init__.py#L831C1-L831C17
+    import torch.cuda
+    import torch.version
+
+    if not torch.cuda._is_compiled():
+        return 0
+    # bypass _device_count_nvml() if rocm (not supported)
+    nvml_count = -1 if torch.version.hip else torch.cuda._device_count_nvml()
+    r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count
+    return r
+
+
+def cuda_device_count_stateless() -> int:
+    """Get number of CUDA devices, caching based on the value of
+    CUDA_VISIBLE_DEVICES at the time of call.
+
+    This should be used instead of torch.cuda.device_count()
+    unless CUDA_VISIBLE_DEVICES has already been set to the desired
+    value."""
+
+    # This can be removed and simply replaced with torch.cuda.get_device_count
+    # after https://github.com/pytorch/pytorch/pull/122815 is released.
+    return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
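
The design point worth noting in the hunk above is that the CUDA_VISIBLE_DEVICES value is passed as an otherwise unused argument so that functools.lru_cache keys the cached count on the visible-device string: the environment is re-read on every call, but the probe only reruns when the value changes. A stripped-down sketch of that pattern under assumed names; probe_devices() and count_stateless() are hypothetical stand-ins, not part of the diff:

import os
from functools import lru_cache
from typing import Optional


def probe_devices() -> int:
    """Hypothetical stand-in for the real NVML/CUDA probing above."""
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        return 8  # pretend the machine exposes 8 GPUs
    return len([d for d in visible.split(",") if d])


@lru_cache(maxsize=8)
def _count(cuda_visible_devices: Optional[str] = None) -> int:
    # The argument is deliberately unused in the body: it only keys the cache,
    # so each distinct CUDA_VISIBLE_DEVICES value gets its own cached probe.
    return probe_devices()


def count_stateless() -> int:
    # Re-read the environment variable on every call; lru_cache decides
    # whether a cached result can be reused.
    return _count(os.environ.get("CUDA_VISIBLE_DEVICES"))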