[Core][Distributed] refactor pynccl (#4591)
[Core][Distributed] refactor pynccl to hold multiple communicators (#4591)
This commit is contained in:
parent
c833101740
commit
208b71bcc1
@ -1,15 +1,15 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.distributed.device_communicators.pynccl_utils as pynccl_utils
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
|
||||
ncclGetUniqueId)
|
||||
from vllm.distributed.parallel_state import (
|
||||
ensure_model_parallel_initialized, get_tensor_model_parallel_cpu_group,
|
||||
init_distributed_environment, with_pynccl_for_all_reduce)
|
||||
from vllm.distributed.communication_op import ( # noqa
|
||||
graph_capture_mode, tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
|
||||
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
|
||||
@ -41,6 +41,9 @@ def worker_fn_wrapper(fn):
|
||||
# and update the environment variables in the function
|
||||
def wrapped_fn(env):
|
||||
update_environment_variables(env)
|
||||
local_rank = os.environ['LOCAL_RANK']
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_distributed_environment()
|
||||
fn()
|
||||
|
||||
@ -49,11 +52,13 @@ def worker_fn_wrapper(fn):
|
||||
|
||||
@worker_fn_wrapper
|
||||
def worker_fn():
|
||||
comm = NCCLCommunicator()
|
||||
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
|
||||
comm.all_reduce(tensor)
|
||||
pynccl_comm = PyNcclCommunicator()
|
||||
tensor = torch.ones(16, 1024, 1024,
|
||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||
with pynccl_comm.change_state(enable=True):
|
||||
pynccl_comm.all_reduce(tensor)
|
||||
result = tensor.mean().cpu().item()
|
||||
assert result == comm.world_size
|
||||
assert result == pynccl_comm.world_size
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
@ -70,16 +75,17 @@ def multiple_tp_worker_fn():
|
||||
torch.distributed.new_group(ranks=[2, 3], backend="gloo")
|
||||
]
|
||||
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
|
||||
comm = NCCLCommunicator(group=group, device=device)
|
||||
pynccl_comm = PyNcclCommunicator(group=group, device=device)
|
||||
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
|
||||
with pynccl_comm.change_state(enable=True):
|
||||
# two groups can communicate independently
|
||||
if torch.distributed.get_rank() in [0, 1]:
|
||||
comm.all_reduce(tensor)
|
||||
comm.all_reduce(tensor)
|
||||
pynccl_comm.all_reduce(tensor)
|
||||
pynccl_comm.all_reduce(tensor)
|
||||
result = tensor.mean().cpu().item()
|
||||
assert result == 4
|
||||
else:
|
||||
comm.all_reduce(tensor)
|
||||
pynccl_comm.all_reduce(tensor)
|
||||
result = tensor.mean().cpu().item()
|
||||
assert result == 2
|
||||
|
||||
@ -88,19 +94,16 @@ def multiple_tp_worker_fn():
|
||||
reason="Need at least 4 GPUs to run the test.")
|
||||
def test_pynccl_multiple_tp():
|
||||
# this tests pynccl for multiple tp groups, in a standalone way
|
||||
# i.e. call `comm.all_reduce` directly
|
||||
# i.e. call `pynccl_comm.all_reduce` directly
|
||||
distributed_run(multiple_tp_worker_fn, 4)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def multiple_tp_with_vllm_worker_fn():
|
||||
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
||||
torch.cuda.set_device(torch.distributed.get_rank())
|
||||
ensure_model_parallel_initialized(2, 2)
|
||||
pynccl_utils.init_process_group(
|
||||
group=get_tensor_model_parallel_cpu_group())
|
||||
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
|
||||
with with_pynccl_for_all_reduce():
|
||||
with graph_capture_mode():
|
||||
# two tp groups can communicate independently
|
||||
if torch.distributed.get_rank() in [0, 1]:
|
||||
tensor = tensor_model_parallel_all_reduce(tensor)
|
||||
@ -125,19 +128,21 @@ def test_pynccl_multiple_tp_with_vllm():
|
||||
def worker_fn_with_cudagraph():
|
||||
with torch.no_grad():
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
comm = NCCLCommunicator()
|
||||
pynccl_comm = PyNcclCommunicator()
|
||||
# run something in the default stream to initialize torch engine
|
||||
a = torch.ones((4, 4), device=f'cuda:{comm.rank}')
|
||||
a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
|
||||
torch.cuda.synchronize()
|
||||
with torch.cuda.graph(graph, stream=comm.stream):
|
||||
with torch.cuda.graph(
|
||||
graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
|
||||
enable=True):
|
||||
# operation during the graph capture is recorded but not executed
|
||||
# see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
|
||||
comm.all_reduce(a)
|
||||
comm.stream.synchronize()
|
||||
assert a.mean().cpu().item() == comm.world_size**0
|
||||
pynccl_comm.all_reduce(a)
|
||||
pynccl_comm.stream.synchronize()
|
||||
assert a.mean().cpu().item() == pynccl_comm.world_size**0
|
||||
graph.replay()
|
||||
comm.stream.synchronize()
|
||||
assert a.mean().cpu().item() == comm.world_size**1
|
||||
pynccl_comm.stream.synchronize()
|
||||
assert a.mean().cpu().item() == pynccl_comm.world_size**1
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
@ -147,7 +152,8 @@ def test_pynccl_with_cudagraph():
|
||||
|
||||
|
||||
def test_ncclGetUniqueId():
|
||||
unique_id = ncclGetUniqueId()
|
||||
lib = NCCLLibrary()
|
||||
unique_id = lib.ncclGetUniqueId()
|
||||
# `list(unique_id.internal)` is something like this:
|
||||
# [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
|
||||
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
|
@ -1,4 +1,5 @@
|
||||
from collections import namedtuple
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@ -8,7 +9,26 @@ from .parallel_state import (get_cpu_world_group,
|
||||
get_tensor_model_parallel_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
is_pynccl_enabled_for_all_reduce)
|
||||
get_tp_pynccl_communicator)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def graph_capture_mode():
|
||||
# In graph capture, we have to be very careful about the collective
|
||||
# operations. The current status is:
|
||||
# allreduce \ Mode | Eager | Graph |
|
||||
# --------------------------------------------
|
||||
# custom allreduce | enabled | enabled |
|
||||
# PyNccl | disabled| enabled |
|
||||
# torch.distributed | enabled | disabled|
|
||||
#
|
||||
# Note that custom allreduce will have a runtime check, if the tensor size
|
||||
# is too large, it will fallback to the next available option.
|
||||
pynccl_comm = get_tp_pynccl_communicator()
|
||||
assert pynccl_comm is not None
|
||||
with pynccl_comm.change_state(enable=True,
|
||||
stream=torch.cuda.current_stream()):
|
||||
yield
|
||||
|
||||
|
||||
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
|
||||
@ -23,7 +43,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
|
||||
TLDR: always assume this function modifies its input, but use the return
|
||||
value as the output.
|
||||
"""
|
||||
from vllm.distributed.device_communicators import pynccl_utils
|
||||
from vllm.distributed.device_communicators.custom_all_reduce import (
|
||||
custom_all_reduce)
|
||||
|
||||
@ -33,8 +52,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
|
||||
out = custom_all_reduce(input_)
|
||||
if out is not None:
|
||||
return out
|
||||
if is_pynccl_enabled_for_all_reduce():
|
||||
pynccl_utils.all_reduce(input_)
|
||||
pynccl_comm = get_tp_pynccl_communicator()
|
||||
if (pynccl_comm is not None and not pynccl_comm.disabled):
|
||||
pynccl_comm.all_reduce(input_)
|
||||
else:
|
||||
torch.distributed.all_reduce(input_,
|
||||
group=get_tensor_model_parallel_group())
|
||||
|
@ -1,26 +1,4 @@
|
||||
# This file is a pure Python wrapper for the NCCL library.
|
||||
# The main purpose is to use NCCL combined with CUDA graph.
|
||||
# Before writing this script, we tried the following approach:
|
||||
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
|
||||
# often gets stuck when initializing the NCCL communicator.
|
||||
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
|
||||
# contains many other potential cuda APIs, that are not allowed during
|
||||
# capturing the CUDA graph. For further details, please check
|
||||
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
|
||||
#
|
||||
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
|
||||
# doable, but we often encounter issues related with nccl versions, and need
|
||||
# to switch between different versions of NCCL. See
|
||||
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
|
||||
# A C/C++ binding is not flexible enough to handle this. It requires
|
||||
# recompilation of the code every time we want to switch between different
|
||||
# versions. This current implementation, with a **pure** Python wrapper, is
|
||||
# more flexible. We can easily switch between different versions of NCCL by
|
||||
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
|
||||
# variable in the code.
|
||||
|
||||
import ctypes
|
||||
import platform
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Union
|
||||
|
||||
# ===================== import region =====================
|
||||
@ -28,217 +6,70 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
|
||||
from vllm.distributed.device_communicators.pynccl_wrapper import (
|
||||
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
|
||||
ncclRedOpTypeEnum, ncclUniqueId)
|
||||
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import find_nccl_library, nccl_integrity_check
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
so_file = find_nccl_library()
|
||||
|
||||
try:
|
||||
# load the library in another process.
|
||||
# if it core dumps, it will not crash the current process
|
||||
nccl_integrity_check(so_file)
|
||||
nccl = ctypes.CDLL(so_file)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to load NCCL library from %s ."
|
||||
"It is expected if you are not running on NVIDIA/AMD GPUs."
|
||||
"Otherwise, the nccl library might not exist, be corrupted "
|
||||
"or it does not support the current platform %s."
|
||||
"One solution is to download libnccl2 version 2.18 from "
|
||||
"https://developer.download.nvidia.com/compute/cuda/repos/ "
|
||||
"and extract the libnccl.so.2 file. If you already have the "
|
||||
"library, please set the environment variable VLLM_NCCL_SO_PATH"
|
||||
" to point to the correct nccl library path.", so_file,
|
||||
platform.platform())
|
||||
raise e
|
||||
|
||||
# === export types and functions from nccl to Python ===
|
||||
# for the original nccl definition, please check
|
||||
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
|
||||
|
||||
ncclResult_t = ctypes.c_int
|
||||
|
||||
_c_ncclGetErrorString = nccl.ncclGetErrorString
|
||||
_c_ncclGetErrorString.restype = ctypes.c_char_p
|
||||
_c_ncclGetErrorString.argtypes = [ncclResult_t]
|
||||
|
||||
|
||||
def NCCL_CHECK(result: ncclResult_t) -> None:
|
||||
if result != 0:
|
||||
error_str = _c_ncclGetErrorString(result)
|
||||
error_str = error_str.decode("utf-8")
|
||||
raise RuntimeError(f"NCCL error: {error_str}")
|
||||
|
||||
|
||||
# equivalent to c declaration:
|
||||
# ncclResult_t ncclGetVersion(int *version);
|
||||
_c_ncclGetVersion = nccl.ncclGetVersion
|
||||
_c_ncclGetVersion.restype = ctypes.c_int
|
||||
_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
|
||||
|
||||
|
||||
def ncclGetVersion() -> str:
|
||||
version = ctypes.c_int()
|
||||
NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
|
||||
# something like 21903 --> "2.19.3"
|
||||
version_str = str(version.value)
|
||||
major = version_str[0].lstrip("0")
|
||||
minor = version_str[1:3].lstrip("0")
|
||||
patch = version_str[3:].lstrip("0")
|
||||
return f"{major}.{minor}.{patch}"
|
||||
|
||||
|
||||
class NcclUniqueId(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 128)]
|
||||
|
||||
|
||||
# equivalent to c declaration:
|
||||
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
|
||||
_c_ncclGetUniqueId = nccl.ncclGetUniqueId
|
||||
_c_ncclGetUniqueId.restype = ctypes.c_int
|
||||
_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
|
||||
|
||||
|
||||
def ncclGetUniqueId() -> NcclUniqueId:
|
||||
unique_id = NcclUniqueId()
|
||||
NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
|
||||
return unique_id
|
||||
|
||||
|
||||
# equivalent to c declaration:
|
||||
# ncclResult_t ncclCommInitRank(
|
||||
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
|
||||
# note that ncclComm_t is a pointer type, so the first argument
|
||||
# is a pointer to a pointer
|
||||
_c_ncclCommInitRank = nccl.ncclCommInitRank
|
||||
_c_ncclCommInitRank.restype = ctypes.c_int
|
||||
_c_ncclCommInitRank.argtypes = [
|
||||
ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
|
||||
]
|
||||
|
||||
ncclDataType_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclDataTypeEnum:
|
||||
ncclInt8 = 0
|
||||
ncclChar = 0
|
||||
ncclUint8 = 1
|
||||
ncclInt32 = 2
|
||||
ncclInt = 2
|
||||
ncclUint32 = 3
|
||||
ncclInt64 = 4
|
||||
ncclUint64 = 5
|
||||
ncclFloat16 = 6
|
||||
ncclHalf = 6
|
||||
ncclFloat32 = 7
|
||||
ncclFloat = 7
|
||||
ncclFloat64 = 8
|
||||
ncclDouble = 8
|
||||
ncclBfloat16 = 9
|
||||
ncclNumTypes = 10
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, dtype: torch.dtype) -> int:
|
||||
if dtype == torch.int8:
|
||||
return cls.ncclInt8
|
||||
if dtype == torch.uint8:
|
||||
return cls.ncclUint8
|
||||
if dtype == torch.int32:
|
||||
return cls.ncclInt32
|
||||
if dtype == torch.int64:
|
||||
return cls.ncclInt64
|
||||
if dtype == torch.float16:
|
||||
return cls.ncclFloat16
|
||||
if dtype == torch.float32:
|
||||
return cls.ncclFloat32
|
||||
if dtype == torch.float64:
|
||||
return cls.ncclFloat64
|
||||
if dtype == torch.bfloat16:
|
||||
return cls.ncclBfloat16
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
ncclRedOp_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclRedOpTypeEnum:
|
||||
ncclSum = 0
|
||||
ncclProd = 1
|
||||
ncclMax = 2
|
||||
ncclMin = 3
|
||||
ncclAvg = 4
|
||||
ncclNumOps = 5
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, op: ReduceOp) -> int:
|
||||
if op == ReduceOp.SUM:
|
||||
return cls.ncclSum
|
||||
if op == ReduceOp.PRODUCT:
|
||||
return cls.ncclProd
|
||||
if op == ReduceOp.MAX:
|
||||
return cls.ncclMax
|
||||
if op == ReduceOp.MIN:
|
||||
return cls.ncclMin
|
||||
if op == ReduceOp.AVG:
|
||||
return cls.ncclAvg
|
||||
raise ValueError(f"Unsupported op: {op}")
|
||||
|
||||
|
||||
# equivalent to c declaration:
|
||||
# ncclResult_t ncclAllReduce(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
# udaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument is a pointer
|
||||
_c_ncclAllReduce = nccl.ncclAllReduce
|
||||
_c_ncclAllReduce.restype = ctypes.c_int
|
||||
_c_ncclAllReduce.argtypes = [
|
||||
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t,
|
||||
ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
|
||||
]
|
||||
|
||||
# be cautious! this is a collective call, it will block until all
|
||||
# processes in the communicator have called this function.
|
||||
# because Python object destruction can happen in random order,
|
||||
# it is better not to call it at all.
|
||||
# equivalent to c declaration:
|
||||
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
|
||||
_c_ncclCommDestroy = nccl.ncclCommDestroy
|
||||
_c_ncclCommDestroy.restype = ctypes.c_int
|
||||
_c_ncclCommDestroy.argtypes = [ctypes.c_void_p]
|
||||
|
||||
|
||||
class NCCLCommunicator:
|
||||
class PyNcclCommunicator:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
device: Optional[Union[int, str, torch.device]] = None,
|
||||
library_path: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the NCCLCommunicator to. If None,
|
||||
device: the device to bind the PyNcclCommunicator to. If None,
|
||||
it will be bind to f"cuda:{local_rank}".
|
||||
library_path: the path to the NCCL library. If None, it will
|
||||
use the default library path.
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bind to a unique device.
|
||||
"""
|
||||
assert dist.is_initialized()
|
||||
group = get_cpu_world_group() if group is None else group
|
||||
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
||||
"NCCLCommunicator should be attached to a non-NCCL group.")
|
||||
"PyNcclCommunicator should be attached to a non-NCCL group.")
|
||||
self.group = group
|
||||
# note: this rank is the rank in the group
|
||||
self.rank = dist.get_rank(group)
|
||||
self.world_size = dist.get_world_size(group)
|
||||
|
||||
# if world_size == 1, no need to create communicator
|
||||
if self.world_size == 1:
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
self.stream = None
|
||||
return
|
||||
try:
|
||||
self.nccl = NCCLLibrary(library_path)
|
||||
except Exception:
|
||||
# disable because of missing NCCL library
|
||||
# e.g. in a non-GPU environment
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
self.stream = None
|
||||
return
|
||||
|
||||
self.available = True
|
||||
self.disabled = False
|
||||
|
||||
logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
|
||||
|
||||
if self.rank == 0:
|
||||
self.unique_id = ncclGetUniqueId()
|
||||
# get the unique id from NCCL
|
||||
self.unique_id = self.nccl.ncclGetUniqueId()
|
||||
else:
|
||||
self.unique_id = NcclUniqueId()
|
||||
# construct an empty unique id
|
||||
self.unique_id = ncclUniqueId()
|
||||
tensor = torch.ByteTensor(list(self.unique_id.internal))
|
||||
ranks = dist.get_process_group_ranks(group)
|
||||
# arg `src` in `broadcast` is the global rank
|
||||
@ -246,7 +77,6 @@ class NCCLCommunicator:
|
||||
byte_list = tensor.tolist()
|
||||
for i, byte in enumerate(byte_list):
|
||||
self.unique_id.internal[i] = byte
|
||||
self.comm = ctypes.c_void_p()
|
||||
if device is None:
|
||||
local_rank = get_local_rank()
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
@ -261,15 +91,25 @@ class NCCLCommunicator:
|
||||
# `torch.cuda.device` is a context manager that changes the
|
||||
# current cuda device to the specified one
|
||||
with torch.cuda.device(device):
|
||||
NCCL_CHECK(
|
||||
_c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
|
||||
self.unique_id, self.rank))
|
||||
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
|
||||
self.world_size, self.unique_id, self.rank)
|
||||
self.stream = torch.cuda.Stream()
|
||||
|
||||
# A small all_reduce for warmup.
|
||||
self.all_reduce(torch.zeros(1, device=device))
|
||||
self.stream.synchronize()
|
||||
|
||||
# by default it is disabled, e.g. in profiling models and prefill phase.
|
||||
# to use it, use under `with obj.change_state(enable=True)`, usually
|
||||
# when we are using CUDA graph.
|
||||
self.disabled = True
|
||||
|
||||
def all_reduce(self,
|
||||
tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
@ -278,10 +118,32 @@ class NCCLCommunicator:
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
NCCL_CHECK(
|
||||
_c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
|
||||
ctypes.c_void_p(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
|
||||
buffer_type(tensor.data_ptr()), tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op), self.comm,
|
||||
ctypes.c_void_p(stream.cuda_stream)))
|
||||
cudaStream_t(stream.cuda_stream))
|
||||
|
||||
@contextmanager
|
||||
def change_state(self,
|
||||
enable: Optional[bool] = None,
|
||||
stream: Optional[torch.cuda.Stream] = None):
|
||||
"""
|
||||
A context manager to change the state of the communicator.
|
||||
"""
|
||||
if enable is None:
|
||||
# guess a default value when not specified
|
||||
enable = self.available
|
||||
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
|
||||
old_disable = self.disabled
|
||||
old_stream = self.stream
|
||||
|
||||
self.stream = stream
|
||||
self.disabled = not enable
|
||||
yield
|
||||
|
||||
self.disabled = old_disable
|
||||
self.stream = old_stream
|
||||
|
@ -1,66 +0,0 @@
|
||||
import contextlib
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
|
||||
ncclGetVersion)
|
||||
except Exception as e:
|
||||
# in non-NVIDIA environments, we can't import the nccl module
|
||||
# e.g. when running on machines with AMD GPUs
|
||||
logger.info("Failed to import NCCL library: %s", e)
|
||||
logger.info("It is expected if you are not running on NVIDIA GPUs.")
|
||||
pass
|
||||
|
||||
comm: Optional["NCCLCommunicator"] = None
|
||||
|
||||
|
||||
def is_initialized() -> bool:
|
||||
"""Returns whether the NCCL backend is initialized."""
|
||||
return comm is not None
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_pynccl_stream(stream: torch.cuda.Stream):
|
||||
"""Set the cuda stream for communication"""
|
||||
try:
|
||||
assert comm is not None
|
||||
comm.stream = stream
|
||||
yield
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
def init_process_group(group: Optional[ProcessGroup] = None) -> None:
|
||||
assert not is_initialized()
|
||||
global comm
|
||||
logger.info("vLLM is using nccl==%s", ncclGetVersion())
|
||||
comm = NCCLCommunicator(group=group)
|
||||
|
||||
|
||||
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
|
||||
"""All-reduces the input tensor across the process group."""
|
||||
assert input_.is_cuda, f"{input_} should be a cuda tensor"
|
||||
assert comm is not None
|
||||
comm.all_reduce(input_, op)
|
||||
|
||||
|
||||
def destroy_process_group() -> None:
|
||||
global comm
|
||||
comm = None
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
"""Returns the world size."""
|
||||
assert comm is not None
|
||||
return comm.world_size
|
||||
|
||||
|
||||
def get_nccl_backend() -> Optional["NCCLCommunicator"]:
|
||||
return comm
|
258
vllm/distributed/device_communicators/pynccl_wrapper.py
Normal file
258
vllm/distributed/device_communicators/pynccl_wrapper.py
Normal file
@ -0,0 +1,258 @@
|
||||
# This file is a pure Python wrapper for the NCCL library.
|
||||
# The main purpose is to use NCCL combined with CUDA graph.
|
||||
# Before writing this script, we tried the following approach:
|
||||
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
|
||||
# often gets stuck when initializing the NCCL communicator.
|
||||
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
|
||||
# contains many other potential cuda APIs, that are not allowed during
|
||||
# capturing the CUDA graph. For further details, please check
|
||||
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
|
||||
#
|
||||
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
|
||||
# doable, but we often encounter issues related with nccl versions, and need
|
||||
# to switch between different versions of NCCL. See
|
||||
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
|
||||
# A C/C++ binding is not flexible enough to handle this. It requires
|
||||
# recompilation of the code every time we want to switch between different
|
||||
# versions. This current implementation, with a **pure** Python wrapper, is
|
||||
# more flexible. We can easily switch between different versions of NCCL by
|
||||
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
|
||||
# variable in the code.
|
||||
|
||||
import ctypes
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import find_nccl_library, nccl_integrity_check
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# === export types and functions from nccl to Python ===
|
||||
# for the original nccl definition, please check
|
||||
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
|
||||
|
||||
ncclResult_t = ctypes.c_int
|
||||
ncclComm_t = ctypes.c_void_p
|
||||
|
||||
|
||||
class ncclUniqueId(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 128)]
|
||||
|
||||
|
||||
cudaStream_t = ctypes.c_void_p
|
||||
buffer_type = ctypes.c_void_p
|
||||
|
||||
ncclDataType_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclDataTypeEnum:
|
||||
ncclInt8 = 0
|
||||
ncclChar = 0
|
||||
ncclUint8 = 1
|
||||
ncclInt32 = 2
|
||||
ncclInt = 2
|
||||
ncclUint32 = 3
|
||||
ncclInt64 = 4
|
||||
ncclUint64 = 5
|
||||
ncclFloat16 = 6
|
||||
ncclHalf = 6
|
||||
ncclFloat32 = 7
|
||||
ncclFloat = 7
|
||||
ncclFloat64 = 8
|
||||
ncclDouble = 8
|
||||
ncclBfloat16 = 9
|
||||
ncclNumTypes = 10
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, dtype: torch.dtype) -> int:
|
||||
if dtype == torch.int8:
|
||||
return cls.ncclInt8
|
||||
if dtype == torch.uint8:
|
||||
return cls.ncclUint8
|
||||
if dtype == torch.int32:
|
||||
return cls.ncclInt32
|
||||
if dtype == torch.int64:
|
||||
return cls.ncclInt64
|
||||
if dtype == torch.float16:
|
||||
return cls.ncclFloat16
|
||||
if dtype == torch.float32:
|
||||
return cls.ncclFloat32
|
||||
if dtype == torch.float64:
|
||||
return cls.ncclFloat64
|
||||
if dtype == torch.bfloat16:
|
||||
return cls.ncclBfloat16
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
ncclRedOp_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclRedOpTypeEnum:
|
||||
ncclSum = 0
|
||||
ncclProd = 1
|
||||
ncclMax = 2
|
||||
ncclMin = 3
|
||||
ncclAvg = 4
|
||||
ncclNumOps = 5
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, op: ReduceOp) -> int:
|
||||
if op == ReduceOp.SUM:
|
||||
return cls.ncclSum
|
||||
if op == ReduceOp.PRODUCT:
|
||||
return cls.ncclProd
|
||||
if op == ReduceOp.MAX:
|
||||
return cls.ncclMax
|
||||
if op == ReduceOp.MIN:
|
||||
return cls.ncclMin
|
||||
if op == ReduceOp.AVG:
|
||||
return cls.ncclAvg
|
||||
raise ValueError(f"Unsupported op: {op}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: List[Any]
|
||||
|
||||
|
||||
class NCCLLibrary:
|
||||
exported_functions = [
|
||||
# const char* ncclGetErrorString(ncclResult_t result)
|
||||
Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
|
||||
# ncclResult_t ncclGetVersion(int *version);
|
||||
Function("ncclGetVersion", ncclResult_t,
|
||||
[ctypes.POINTER(ctypes.c_int)]),
|
||||
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
|
||||
Function("ncclGetUniqueId", ncclResult_t,
|
||||
[ctypes.POINTER(ncclUniqueId)]),
|
||||
# ncclResult_t ncclCommInitRank(
|
||||
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
|
||||
# note that ncclComm_t is a pointer type, so the first argument
|
||||
# is a pointer to a pointer
|
||||
Function("ncclCommInitRank", ncclResult_t, [
|
||||
ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
|
||||
ctypes.c_int
|
||||
]),
|
||||
# ncclResult_t ncclAllReduce(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function("ncclAllReduce", ncclResult_t, [
|
||||
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
|
||||
ncclRedOp_t, ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# be cautious! this is a collective call, it will block until all
|
||||
# processes in the communicator have called this function.
|
||||
# because Python object destruction can happen in random order,
|
||||
# it is better not to call it at all.
|
||||
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
|
||||
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: Dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the corresponding dictionary
|
||||
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
|
||||
so_file = so_file or find_nccl_library()
|
||||
|
||||
try:
|
||||
# load the library in another process.
|
||||
# if it core dumps, it will not crash the current process
|
||||
nccl_integrity_check(so_file)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to load NCCL library from %s ."
|
||||
"It is expected if you are not running on NVIDIA/AMD GPUs."
|
||||
"Otherwise, the nccl library might not exist, be corrupted "
|
||||
"or it does not support the current platform %s."
|
||||
"One solution is to download libnccl2 version 2.18 from "
|
||||
"https://developer.download.nvidia.com/compute/cuda/repos/ "
|
||||
"and extract the libnccl.so.2 file. If you already have the "
|
||||
"library, please set the environment variable VLLM_NCCL_SO_PATH"
|
||||
" to point to the correct nccl library path.", so_file,
|
||||
platform.platform())
|
||||
raise e
|
||||
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
NCCLLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = NCCLLibrary.path_to_library_cache[so_file]
|
||||
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
_funcs = {}
|
||||
for func in NCCLLibrary.exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
f.argtypes = func.argtypes
|
||||
_funcs[func.name] = f
|
||||
NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
|
||||
self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
|
||||
|
||||
def ncclGetErrorString(self, result: ncclResult_t) -> str:
|
||||
return self._funcs["ncclGetErrorString"](result).decode("utf-8")
|
||||
|
||||
def NCCL_CHECK(self, result: ncclResult_t) -> None:
|
||||
if result != 0:
|
||||
error_str = self.ncclGetErrorString(result)
|
||||
raise RuntimeError(f"NCCL error: {error_str}")
|
||||
|
||||
def ncclGetVersion(self) -> str:
|
||||
version = ctypes.c_int()
|
||||
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
|
||||
version_str = str(version.value)
|
||||
# something like 21903 --> "2.19.3"
|
||||
major = version_str[0].lstrip("0")
|
||||
minor = version_str[1:3].lstrip("0")
|
||||
patch = version_str[3:].lstrip("0")
|
||||
return f"{major}.{minor}.{patch}"
|
||||
|
||||
def ncclGetUniqueId(self) -> ncclUniqueId:
|
||||
unique_id = ncclUniqueId()
|
||||
self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
|
||||
ctypes.byref(unique_id)))
|
||||
return unique_id
|
||||
|
||||
def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
|
||||
rank: int) -> ncclComm_t:
|
||||
comm = ncclComm_t()
|
||||
self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
|
||||
world_size, unique_id,
|
||||
rank))
|
||||
return comm
|
||||
|
||||
def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, op: int, comm: ncclComm_t,
|
||||
stream: cudaStream_t) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# and `op` should be `ncclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
|
||||
datatype, op, comm,
|
||||
stream))
|
||||
|
||||
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
|
||||
"ncclComm_t", "cudaStream_t", "buffer_type"
|
||||
]
|
@ -3,10 +3,10 @@
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
"""Tensor and pipeline parallel groups."""
|
||||
import contextlib
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
@ -14,10 +14,11 @@ from vllm.logger import init_logger
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Tensor model parallel group that the current rank belongs to.
|
||||
_TP_DEVICE_GROUP = None
|
||||
_TP_CPU_GROUP = None
|
||||
_TP_DEVICE_GROUP: Optional[ProcessGroup] = None
|
||||
_TP_CPU_GROUP: Optional[ProcessGroup] = None
|
||||
_TP_PYNCCL_COMMUNICATOR = None
|
||||
# Pipeline model parallel group that the current rank belongs to.
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = None
|
||||
_PP_DEVICE_GROUP: Optional[ProcessGroup] = None
|
||||
|
||||
# when people blindly call `torch.distributed.all_reduce` etc,
|
||||
# it will use this group. It is initialized with the `backend`
|
||||
@ -41,11 +42,16 @@ _CPU_WORLD_GROUP = None
|
||||
|
||||
# A list of global ranks for each pipeline group to ease calculation of the
|
||||
# source rank when broadcasting from the first or last pipeline stage.
|
||||
_PIPELINE_GLOBAL_RANKS = None
|
||||
_PP_GLOBAL_RANKS: Optional[List[int]] = None
|
||||
|
||||
_LOCAL_RANK = -1
|
||||
|
||||
|
||||
def get_tp_pynccl_communicator():
|
||||
global _TP_PYNCCL_COMMUNICATOR
|
||||
return _TP_PYNCCL_COMMUNICATOR
|
||||
|
||||
|
||||
def get_local_rank():
|
||||
global _LOCAL_RANK
|
||||
return _LOCAL_RANK
|
||||
@ -80,10 +86,20 @@ def init_distributed_environment(
|
||||
# set the local rank
|
||||
# local_rank is not available in torch ProcessGroup,
|
||||
# see https://github.com/pytorch/pytorch/issues/122816
|
||||
if local_rank == -1 and distributed_init_method == "env://":
|
||||
if local_rank == -1:
|
||||
# local rank not set, this usually happens in single-node
|
||||
# setting, where we can use rank as local rank
|
||||
if distributed_init_method == "env://":
|
||||
local_rank = envs.LOCAL_RANK
|
||||
else:
|
||||
local_rank = rank
|
||||
global _LOCAL_RANK
|
||||
_LOCAL_RANK = local_rank
|
||||
# A small all_reduce for warmup.
|
||||
data = torch.zeros(1)
|
||||
if torch.cuda.is_available():
|
||||
data = data.to(device=f"cuda:{local_rank}")
|
||||
torch.distributed.all_reduce(data)
|
||||
|
||||
|
||||
def initialize_model_parallel(
|
||||
@ -133,29 +149,36 @@ def initialize_model_parallel(
|
||||
rank = torch.distributed.get_rank()
|
||||
|
||||
# Build the tensor model-parallel groups.
|
||||
global _TP_DEVICE_GROUP, _TP_CPU_GROUP
|
||||
global _TP_DEVICE_GROUP, _TP_CPU_GROUP, _TP_PYNCCL_COMMUNICATOR
|
||||
assert _TP_DEVICE_GROUP is None, (
|
||||
"tensor model parallel group is already initialized")
|
||||
for i in range(num_tensor_model_parallel_groups):
|
||||
ranks = range(i * tensor_model_parallel_size,
|
||||
(i + 1) * tensor_model_parallel_size)
|
||||
ranks = list(
|
||||
range(i * tensor_model_parallel_size,
|
||||
(i + 1) * tensor_model_parallel_size))
|
||||
group = torch.distributed.new_group(ranks, backend=backend)
|
||||
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
|
||||
if rank in ranks:
|
||||
_TP_DEVICE_GROUP = group
|
||||
_TP_CPU_GROUP = cpu_group
|
||||
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
_TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
|
||||
group=_TP_CPU_GROUP,
|
||||
device=_LOCAL_RANK,
|
||||
)
|
||||
|
||||
# Build the pipeline model-parallel groups.
|
||||
global _PIPELINE_MODEL_PARALLEL_GROUP
|
||||
global _PIPELINE_GLOBAL_RANKS
|
||||
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, (
|
||||
global _PP_DEVICE_GROUP
|
||||
global _PP_GLOBAL_RANKS
|
||||
assert _PP_DEVICE_GROUP is None, (
|
||||
"pipeline model parallel group is already initialized")
|
||||
for i in range(num_pipeline_model_parallel_groups):
|
||||
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
|
||||
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
|
||||
group = torch.distributed.new_group(ranks, backend=backend)
|
||||
if rank in ranks:
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = group
|
||||
_PIPELINE_GLOBAL_RANKS = ranks
|
||||
_PP_DEVICE_GROUP = group
|
||||
_PP_GLOBAL_RANKS = ranks
|
||||
|
||||
|
||||
def ensure_model_parallel_initialized(
|
||||
@ -188,8 +211,7 @@ def ensure_model_parallel_initialized(
|
||||
|
||||
def model_parallel_is_initialized():
|
||||
"""Check if tensor and pipeline parallel groups are initialized."""
|
||||
return (_TP_DEVICE_GROUP is not None
|
||||
and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
|
||||
return (_TP_DEVICE_GROUP is not None and _PP_DEVICE_GROUP is not None)
|
||||
|
||||
|
||||
def get_cpu_world_group():
|
||||
@ -214,9 +236,9 @@ def get_tensor_model_parallel_cpu_group():
|
||||
|
||||
def get_pipeline_model_parallel_group():
|
||||
"""Get the pipeline model parallel group the caller rank belongs to."""
|
||||
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, (
|
||||
assert _PP_DEVICE_GROUP is not None, (
|
||||
"pipeline model parallel group is not initialized")
|
||||
return _PIPELINE_MODEL_PARALLEL_GROUP
|
||||
return _PP_DEVICE_GROUP
|
||||
|
||||
|
||||
def get_tensor_model_parallel_world_size():
|
||||
@ -253,36 +275,36 @@ def get_tensor_model_parallel_src_rank():
|
||||
def get_pipeline_model_parallel_first_rank():
|
||||
"""Return the global rank of the first process in the pipeline for the
|
||||
current tensor parallel group"""
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, (
|
||||
assert _PP_GLOBAL_RANKS is not None, (
|
||||
"Pipeline parallel group is not initialized")
|
||||
return _PIPELINE_GLOBAL_RANKS[0]
|
||||
return _PP_GLOBAL_RANKS[0]
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_last_rank():
|
||||
"""Return the global rank of the last process in the pipeline for the
|
||||
current tensor parallel group"""
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, (
|
||||
assert _PP_GLOBAL_RANKS is not None, (
|
||||
"Pipeline parallel group is not initialized")
|
||||
last_rank_local = get_pipeline_model_parallel_world_size() - 1
|
||||
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
|
||||
return _PP_GLOBAL_RANKS[last_rank_local]
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_next_rank():
|
||||
"""Return the global rank that follows the caller in the pipeline"""
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, (
|
||||
assert _PP_GLOBAL_RANKS is not None, (
|
||||
"Pipeline parallel group is not initialized")
|
||||
rank_in_pipeline = get_pipeline_model_parallel_rank()
|
||||
world_size = get_pipeline_model_parallel_world_size()
|
||||
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
|
||||
return _PP_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_prev_rank():
|
||||
"""Return the global rank that precedes the caller in the pipeline"""
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, (
|
||||
assert _PP_GLOBAL_RANKS is not None, (
|
||||
"Pipeline parallel group is not initialized")
|
||||
rank_in_pipeline = get_pipeline_model_parallel_rank()
|
||||
world_size = get_pipeline_model_parallel_world_size()
|
||||
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
|
||||
return _PP_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
|
||||
|
||||
|
||||
def destroy_model_parallel():
|
||||
@ -295,45 +317,12 @@ def destroy_model_parallel():
|
||||
if _TP_CPU_GROUP:
|
||||
torch.distributed.destroy_process_group(_TP_CPU_GROUP)
|
||||
_TP_CPU_GROUP = None
|
||||
global _PIPELINE_MODEL_PARALLEL_GROUP
|
||||
if _PIPELINE_MODEL_PARALLEL_GROUP:
|
||||
torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = None
|
||||
global _PIPELINE_GLOBAL_RANKS
|
||||
_PIPELINE_GLOBAL_RANKS = None
|
||||
from vllm.distributed.device_communicators import pynccl_utils
|
||||
global _TP_PYNCCL_COMMUNICATOR
|
||||
_TP_PYNCCL_COMMUNICATOR = None
|
||||
|
||||
# Destroy the pynccl states if any.
|
||||
pynccl_utils.destroy_process_group()
|
||||
|
||||
|
||||
# Whether to use pynccl for nccl all reduce.
|
||||
# We use pynccl for all reduce when using CUDA graph, because torch.distributed
|
||||
# is not well supported by CUDA graph.
|
||||
_ENABLE_PYNCCL_FOR_ALL_REDUCE = False
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def with_pynccl_for_all_reduce():
|
||||
from vllm.distributed.device_communicators import pynccl_utils
|
||||
"""use pynccl instead of torch.distributed for all reduce"""
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if tp_size == 1:
|
||||
# No-op.
|
||||
# NOTE(woosuk): We don't initialize pynccl when tp_size is 1.
|
||||
yield
|
||||
else:
|
||||
global _ENABLE_PYNCCL_FOR_ALL_REDUCE
|
||||
old = _ENABLE_PYNCCL_FOR_ALL_REDUCE
|
||||
_ENABLE_PYNCCL_FOR_ALL_REDUCE = True
|
||||
|
||||
stream = torch.cuda.current_stream()
|
||||
with pynccl_utils.set_pynccl_stream(stream):
|
||||
yield
|
||||
_ENABLE_PYNCCL_FOR_ALL_REDUCE = old
|
||||
|
||||
|
||||
def is_pynccl_enabled_for_all_reduce():
|
||||
"""check if pynccl is enabled for all reduce"""
|
||||
global _ENABLE_PYNCCL_FOR_ALL_REDUCE
|
||||
return _ENABLE_PYNCCL_FOR_ALL_REDUCE
|
||||
global _PP_DEVICE_GROUP
|
||||
if _PP_DEVICE_GROUP:
|
||||
torch.distributed.destroy_process_group(_PP_DEVICE_GROUP)
|
||||
_PP_DEVICE_GROUP = None
|
||||
global _PP_GLOBAL_RANKS
|
||||
_PP_GLOBAL_RANKS = None
|
||||
|
@ -1,4 +1,3 @@
|
||||
import contextlib
|
||||
import time
|
||||
from enum import IntEnum
|
||||
from typing import Dict, List, NamedTuple, Optional, Set, Tuple
|
||||
@ -12,9 +11,9 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
VisionLanguageConfig)
|
||||
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
|
||||
from vllm.distributed.device_communicators import (custom_all_reduce,
|
||||
pynccl_utils)
|
||||
from vllm.distributed import broadcast_tensor_dict
|
||||
from vllm.distributed.communication_op import graph_capture_mode
|
||||
from vllm.distributed.device_communicators import custom_all_reduce
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -917,10 +916,6 @@ class ModelRunner:
|
||||
Since it is used for decoding-only, it assumes there's only 1 token
|
||||
per sequence in the batch.
|
||||
"""
|
||||
# NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
|
||||
# deleted before the CUDA graphs.
|
||||
self.pynccl_backend = pynccl_utils.get_nccl_backend()
|
||||
|
||||
assert not self.model_config.enforce_eager
|
||||
logger.info("Capturing the model for CUDA graphs. This may lead to "
|
||||
"unexpected consequences if the model is not static. To "
|
||||
@ -1046,7 +1041,7 @@ class CUDAGraphRunner:
|
||||
# Run the model once without capturing the graph.
|
||||
# This is to make sure that the captured graph does not include the
|
||||
# kernel launches for initial benchmarking (e.g., Triton autotune).
|
||||
with _maybe_pynccl():
|
||||
with graph_capture_mode():
|
||||
self.model(
|
||||
input_ids,
|
||||
positions,
|
||||
@ -1061,7 +1056,7 @@ class CUDAGraphRunner:
|
||||
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
|
||||
self._graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117
|
||||
with _maybe_pynccl():
|
||||
with graph_capture_mode():
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
positions,
|
||||
@ -1113,16 +1108,6 @@ class CUDAGraphRunner:
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _maybe_pynccl():
|
||||
if pynccl_utils.is_initialized(
|
||||
) and not custom_all_reduce.is_initialized():
|
||||
with with_pynccl_for_all_reduce():
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
def _get_graph_batch_size(batch_size: int) -> int:
|
||||
"""Returns the padded batch size given actual batch size.
|
||||
|
||||
|
@ -11,9 +11,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
VisionLanguageConfig)
|
||||
from vllm.distributed import (broadcast_tensor_dict,
|
||||
ensure_model_parallel_initialized,
|
||||
get_tensor_model_parallel_cpu_group,
|
||||
init_distributed_environment)
|
||||
from vllm.distributed.device_communicators import pynccl_utils
|
||||
from vllm.distributed.device_communicators.custom_all_reduce import (
|
||||
init_custom_ar)
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -306,29 +304,10 @@ def init_worker_distributed_environment(
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
if pynccl_utils.is_initialized():
|
||||
pynccl_world_size = pynccl_utils.get_world_size()
|
||||
if pynccl_world_size != parallel_config.world_size:
|
||||
raise RuntimeError(
|
||||
"pynccl is already initialized but the pynccl world "
|
||||
"size does not match parallel_config.world_size "
|
||||
f"({pynccl_world_size} vs. {parallel_config.world_size}).")
|
||||
elif parallel_config.world_size > 1:
|
||||
# NOTE(woosuk): We don't initialize pynccl process group when world size
|
||||
# is 1.
|
||||
# NOTE(kaichao): By default, pynccl is initialized for tp group.
|
||||
pynccl_utils.init_process_group(
|
||||
group=get_tensor_model_parallel_cpu_group())
|
||||
|
||||
# Initialize a custom fast all-reduce implementation.
|
||||
if not parallel_config.disable_custom_all_reduce:
|
||||
init_custom_ar()
|
||||
|
||||
# A small all_reduce for warmup.
|
||||
torch.distributed.all_reduce(torch.zeros(1).cuda())
|
||||
if pynccl_utils.is_initialized():
|
||||
pynccl_utils.all_reduce(torch.zeros(1).cuda())
|
||||
|
||||
|
||||
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||
# Check if the GPU supports the dtype.
|
||||
|
Loading…
x
Reference in New Issue
Block a user