[Core][Distributed] refactor pynccl (#4591)

[Core][Distributed] refactor pynccl to hold multiple communicators (#4591)
youkaichao authored 2024-05-09 19:48:43 -07:00, committed by GitHub
parent c833101740
commit 208b71bcc1
8 changed files with 466 additions and 433 deletions

tests/distributed/test_pynccl.py

@ -1,15 +1,15 @@
import multiprocessing
import os
import pytest
import torch
import vllm.distributed.device_communicators.pynccl_utils as pynccl_utils
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetUniqueId)
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized, get_tensor_model_parallel_cpu_group,
init_distributed_environment, with_pynccl_for_all_reduce)
from vllm.distributed.communication_op import ( # noqa
graph_capture_mode, tensor_model_parallel_all_reduce)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.utils import update_environment_variables
@ -41,6 +41,9 @@ def worker_fn_wrapper(fn):
# and update the environment variables in the function
def wrapped_fn(env):
update_environment_variables(env)
local_rank = os.environ['LOCAL_RANK']
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
init_distributed_environment()
fn()
@ -49,11 +52,13 @@ def worker_fn_wrapper(fn):
@worker_fn_wrapper
def worker_fn():
comm = NCCLCommunicator()
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
comm.all_reduce(tensor)
pynccl_comm = PyNcclCommunicator()
tensor = torch.ones(16, 1024, 1024,
dtype=torch.float32).cuda(pynccl_comm.rank)
with pynccl_comm.change_state(enable=True):
pynccl_comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == comm.world_size
assert result == pynccl_comm.world_size
@pytest.mark.skipif(torch.cuda.device_count() < 2,
@ -70,37 +75,35 @@ def multiple_tp_worker_fn():
torch.distributed.new_group(ranks=[2, 3], backend="gloo")
]
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
comm = NCCLCommunicator(group=group, device=device)
pynccl_comm = PyNcclCommunicator(group=group, device=device)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
# two groups can communicate independently
if torch.distributed.get_rank() in [0, 1]:
comm.all_reduce(tensor)
comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == 4
else:
comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == 2
with pynccl_comm.change_state(enable=True):
# two groups can communicate independently
if torch.distributed.get_rank() in [0, 1]:
pynccl_comm.all_reduce(tensor)
pynccl_comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == 4
else:
pynccl_comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == 2
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_tp():
# this tests pynccl for multiple tp groups, in a standalone way
# i.e. call `comm.all_reduce` directly
# i.e. call `pynccl_comm.all_reduce` directly
distributed_run(multiple_tp_worker_fn, 4)
@worker_fn_wrapper
def multiple_tp_with_vllm_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
torch.cuda.set_device(torch.distributed.get_rank())
ensure_model_parallel_initialized(2, 2)
pynccl_utils.init_process_group(
group=get_tensor_model_parallel_cpu_group())
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
with with_pynccl_for_all_reduce():
with graph_capture_mode():
# two tp groups can communicate independently
if torch.distributed.get_rank() in [0, 1]:
tensor = tensor_model_parallel_all_reduce(tensor)
@ -125,19 +128,21 @@ def test_pynccl_multiple_tp_with_vllm():
def worker_fn_with_cudagraph():
with torch.no_grad():
graph = torch.cuda.CUDAGraph()
comm = NCCLCommunicator()
pynccl_comm = PyNcclCommunicator()
# run something in the default stream to initialize torch engine
a = torch.ones((4, 4), device=f'cuda:{comm.rank}')
a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
torch.cuda.synchronize()
with torch.cuda.graph(graph, stream=comm.stream):
with torch.cuda.graph(
graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
enable=True):
# operation during the graph capture is recorded but not executed
# see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
comm.all_reduce(a)
comm.stream.synchronize()
assert a.mean().cpu().item() == comm.world_size**0
pynccl_comm.all_reduce(a)
pynccl_comm.stream.synchronize()
assert a.mean().cpu().item() == pynccl_comm.world_size**0
graph.replay()
comm.stream.synchronize()
assert a.mean().cpu().item() == comm.world_size**1
pynccl_comm.stream.synchronize()
assert a.mean().cpu().item() == pynccl_comm.world_size**1
@pytest.mark.skipif(torch.cuda.device_count() < 2,
@ -147,7 +152,8 @@ def test_pynccl_with_cudagraph():
def test_ncclGetUniqueId():
unique_id = ncclGetUniqueId()
lib = NCCLLibrary()
unique_id = lib.ncclGetUniqueId()
# `list(unique_id.internal)` is something like this:
# [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
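
A minimal standalone sketch of the new usage pattern exercised by the tests above, for readers who want to try it outside the test harness. It assumes a node with at least 2 GPUs, a working vLLM install, and a launch via `torchrun --nproc_per_node=2`; the explicit gloo group is an illustrative choice, not part of this commit.

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

if __name__ == "__main__":
    # PyNcclCommunicator must be attached to a non-NCCL group, so init with gloo
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    group = dist.new_group(backend="gloo")  # a plain gloo group over all ranks
    pynccl_comm = PyNcclCommunicator(group=group, device=rank)
    tensor = torch.ones(1024, dtype=torch.float32, device=f"cuda:{rank}")

    # the communicator starts out disabled; enable it only around the collective
    with pynccl_comm.change_state(enable=True):
        pynccl_comm.all_reduce(tensor)
    torch.cuda.synchronize()
    assert tensor.mean().item() == pynccl_comm.world_size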

vllm/distributed/communication_op.py

@ -1,4 +1,5 @@
from collections import namedtuple
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
@ -8,7 +9,26 @@ from .parallel_state import (get_cpu_world_group,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
is_pynccl_enabled_for_all_reduce)
get_tp_pynccl_communicator)
@contextmanager
def graph_capture_mode():
# In graph capture, we have to be very careful about the collective
# operations. The current status is:
# allreduce \ Mode | Eager | Graph |
# --------------------------------------------
# custom allreduce | enabled | enabled |
# PyNccl | disabled| enabled |
# torch.distributed | enabled | disabled|
#
# Note that custom allreduce performs a runtime check: if the tensor size
# is too large, it falls back to the next available option.
pynccl_comm = get_tp_pynccl_communicator()
assert pynccl_comm is not None
with pynccl_comm.change_state(enable=True,
stream=torch.cuda.current_stream()):
yield
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@ -23,7 +43,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
TLDR: always assume this function modifies its input, but use the return
value as the output.
"""
from vllm.distributed.device_communicators import pynccl_utils
from vllm.distributed.device_communicators.custom_all_reduce import (
custom_all_reduce)
@ -33,8 +52,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
out = custom_all_reduce(input_)
if out is not None:
return out
if is_pynccl_enabled_for_all_reduce():
pynccl_utils.all_reduce(input_)
pynccl_comm = get_tp_pynccl_communicator()
if (pynccl_comm is not None and not pynccl_comm.disabled):
pynccl_comm.all_reduce(input_)
else:
torch.distributed.all_reduce(input_,
group=get_tensor_model_parallel_group())
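
A hedged sketch of how `graph_capture_mode` is meant to be used around CUDA graph capture, mirroring the `CUDAGraphRunner` changes later in this commit. It assumes `init_distributed_environment()` and `ensure_model_parallel_initialized(...)` have already run with a tensor-parallel size greater than 1; the helper name `capture_allreduce_graph` is hypothetical.

import torch

from vllm.distributed.communication_op import (graph_capture_mode,
                                               tensor_model_parallel_all_reduce)


def capture_allreduce_graph(x: torch.Tensor) -> torch.cuda.CUDAGraph:
    # run once outside the graph (pynccl enabled) so lazy initialization and
    # autotuning are not recorded, as CUDAGraphRunner does before capturing
    with graph_capture_mode():
        tensor_model_parallel_all_reduce(x)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    # torch.cuda.graph switches to a capture stream; graph_capture_mode then
    # enables pynccl and binds it to the current (capture) stream, so the
    # NCCL kernel is recorded into the graph instead of running eagerly
    with torch.cuda.graph(graph), graph_capture_mode():
        tensor_model_parallel_all_reduce(x)
    return graph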

vllm/distributed/device_communicators/pynccl.py

@ -1,26 +1,4 @@
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
# often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# contains many other potential cuda APIs, that are not allowed during
# capturing the CUDA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related with nccl versions, and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
# variable in the code.
import ctypes
import platform
from contextlib import contextmanager
from typing import Optional, Union
# ===================== import region =====================
@ -28,217 +6,70 @@ import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
from vllm.distributed.device_communicators.pynccl_wrapper import (
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
ncclRedOpTypeEnum, ncclUniqueId)
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
from vllm.logger import init_logger
from vllm.utils import find_nccl_library, nccl_integrity_check
logger = init_logger(__name__)
so_file = find_nccl_library()
try:
# load the library in another process.
# if it core dumps, it will not crash the current process
nccl_integrity_check(so_file)
nccl = ctypes.CDLL(so_file)
except Exception as e:
logger.error(
"Failed to load NCCL library from %s ."
"It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise, the nccl library might not exist, be corrupted "
"or it does not support the current platform %s."
"One solution is to download libnccl2 version 2.18 from "
"https://developer.download.nvidia.com/compute/cuda/repos/ "
"and extract the libnccl.so.2 file. If you already have the "
"library, please set the environment variable VLLM_NCCL_SO_PATH"
" to point to the correct nccl library path.", so_file,
platform.platform())
raise e
# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
ncclResult_t = ctypes.c_int
_c_ncclGetErrorString = nccl.ncclGetErrorString
_c_ncclGetErrorString.restype = ctypes.c_char_p
_c_ncclGetErrorString.argtypes = [ncclResult_t]
def NCCL_CHECK(result: ncclResult_t) -> None:
if result != 0:
error_str = _c_ncclGetErrorString(result)
error_str = error_str.decode("utf-8")
raise RuntimeError(f"NCCL error: {error_str}")
# equivalent to c declaration:
# ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion = nccl.ncclGetVersion
_c_ncclGetVersion.restype = ctypes.c_int
_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
def ncclGetVersion() -> str:
version = ctypes.c_int()
NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
# something like 21903 --> "2.19.3"
version_str = str(version.value)
major = version_str[0].lstrip("0")
minor = version_str[1:3].lstrip("0")
patch = version_str[3:].lstrip("0")
return f"{major}.{minor}.{patch}"
class NcclUniqueId(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
# equivalent to c declaration:
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
_c_ncclGetUniqueId = nccl.ncclGetUniqueId
_c_ncclGetUniqueId.restype = ctypes.c_int
_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
def ncclGetUniqueId() -> NcclUniqueId:
unique_id = NcclUniqueId()
NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
return unique_id
# equivalent to c declaration:
# ncclResult_t ncclCommInitRank(
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
# note that ncclComm_t is a pointer type, so the first argument
# is a pointer to a pointer
_c_ncclCommInitRank = nccl.ncclCommInitRank
_c_ncclCommInitRank.restype = ctypes.c_int
_c_ncclCommInitRank.argtypes = [
ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
]
ncclDataType_t = ctypes.c_int
class ncclDataTypeEnum:
ncclInt8 = 0
ncclChar = 0
ncclUint8 = 1
ncclInt32 = 2
ncclInt = 2
ncclUint32 = 3
ncclInt64 = 4
ncclUint64 = 5
ncclFloat16 = 6
ncclHalf = 6
ncclFloat32 = 7
ncclFloat = 7
ncclFloat64 = 8
ncclDouble = 8
ncclBfloat16 = 9
ncclNumTypes = 10
@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.ncclInt8
if dtype == torch.uint8:
return cls.ncclUint8
if dtype == torch.int32:
return cls.ncclInt32
if dtype == torch.int64:
return cls.ncclInt64
if dtype == torch.float16:
return cls.ncclFloat16
if dtype == torch.float32:
return cls.ncclFloat32
if dtype == torch.float64:
return cls.ncclFloat64
if dtype == torch.bfloat16:
return cls.ncclBfloat16
raise ValueError(f"Unsupported dtype: {dtype}")
ncclRedOp_t = ctypes.c_int
class ncclRedOpTypeEnum:
ncclSum = 0
ncclProd = 1
ncclMax = 2
ncclMin = 3
ncclAvg = 4
ncclNumOps = 5
@classmethod
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.ncclSum
if op == ReduceOp.PRODUCT:
return cls.ncclProd
if op == ReduceOp.MAX:
return cls.ncclMax
if op == ReduceOp.MIN:
return cls.ncclMin
if op == ReduceOp.AVG:
return cls.ncclAvg
raise ValueError(f"Unsupported op: {op}")
# equivalent to c declaration:
# ncclResult_t ncclAllReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument is a pointer
_c_ncclAllReduce = nccl.ncclAllReduce
_c_ncclAllReduce.restype = ctypes.c_int
_c_ncclAllReduce.argtypes = [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t,
ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
]
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
# it is better not to call it at all.
# equivalent to c declaration:
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
_c_ncclCommDestroy = nccl.ncclCommDestroy
_c_ncclCommDestroy.restype = ctypes.c_int
_c_ncclCommDestroy.argtypes = [ctypes.c_void_p]
class NCCLCommunicator:
class PyNcclCommunicator:
def __init__(
self,
group: Optional[ProcessGroup] = None,
device: Optional[Union[int, str, torch.device]] = None,
library_path: Optional[str] = None,
):
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the NCCLCommunicator to. If None,
device: the device to bind the PyNcclCommunicator to. If None,
it will be bound to f"cuda:{local_rank}".
library_path: the path to the NCCL library. If None, it will
use the default library path.
It is the caller's responsibility to make sure each communicator
is bound to a unique device.
"""
assert dist.is_initialized()
group = get_cpu_world_group() if group is None else group
assert dist.get_backend(group) != dist.Backend.NCCL, (
"NCCLCommunicator should be attached to a non-NCCL group.")
"PyNcclCommunicator should be attached to a non-NCCL group.")
self.group = group
# note: this rank is the rank in the group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
# if world_size == 1, no need to create communicator
if self.world_size == 1:
self.available = False
self.disabled = True
self.stream = None
return
try:
self.nccl = NCCLLibrary(library_path)
except Exception:
# disable because of missing NCCL library
# e.g. in a non-GPU environment
self.available = False
self.disabled = True
self.stream = None
return
self.available = True
self.disabled = False
logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
if self.rank == 0:
self.unique_id = ncclGetUniqueId()
# get the unique id from NCCL
self.unique_id = self.nccl.ncclGetUniqueId()
else:
self.unique_id = NcclUniqueId()
# construct an empty unique id
self.unique_id = ncclUniqueId()
tensor = torch.ByteTensor(list(self.unique_id.internal))
ranks = dist.get_process_group_ranks(group)
# arg `src` in `broadcast` is the global rank
@ -246,7 +77,6 @@ class NCCLCommunicator:
byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
self.comm = ctypes.c_void_p()
if device is None:
local_rank = get_local_rank()
device = torch.device(f"cuda:{local_rank}")
@ -261,15 +91,25 @@ class NCCLCommunicator:
# `torch.cuda.device` is a context manager that changes the
# current cuda device to the specified one
with torch.cuda.device(device):
NCCL_CHECK(
_c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, self.rank))
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
self.world_size, self.unique_id, self.rank)
self.stream = torch.cuda.Stream()
# A small all_reduce for warmup.
self.all_reduce(torch.zeros(1, device=device))
self.stream.synchronize()
# by default the communicator is disabled, e.g. when profiling models or
# during the prefill phase. To use it, wrap the calls in
# `with obj.change_state(enable=True)`, usually when running under CUDA graph.
self.disabled = True
def all_reduce(self,
tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
@ -278,10 +118,32 @@ class NCCLCommunicator:
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
NCCL_CHECK(
_c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
ctypes.c_void_p(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
ctypes.c_void_p(stream.cuda_stream)))
self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))
@contextmanager
def change_state(self,
enable: Optional[bool] = None,
stream: Optional[torch.cuda.Stream] = None):
"""
A context manager to change the state of the communicator.
"""
if enable is None:
# guess a default value when not specified
enable = self.available
if stream is None:
stream = self.stream
old_disable = self.disabled
old_stream = self.stream
self.stream = stream
self.disabled = not enable
yield
self.disabled = old_disable
self.stream = old_stream
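
Because the hunk boundary above hides the broadcast step, here is a simplified, hedged restatement of the unique-id handshake that `PyNcclCommunicator.__init__` performs; the helper name `exchange_unique_id` is hypothetical, and it assumes a gloo process group is already initialized.

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.distributed.device_communicators.pynccl_wrapper import (NCCLLibrary,
                                                                  ncclUniqueId)


def exchange_unique_id(group: ProcessGroup) -> ncclUniqueId:
    nccl = NCCLLibrary()
    if dist.get_rank(group) == 0:
        unique_id = nccl.ncclGetUniqueId()  # only the first rank asks NCCL for an id
    else:
        unique_id = ncclUniqueId()          # the other ranks start from an empty struct
    # ship the 128 raw bytes over the CPU (gloo) group; this is why the
    # communicator has to be attached to a non-NCCL process group
    tensor = torch.ByteTensor(list(unique_id.internal))
    ranks = dist.get_process_group_ranks(group)
    dist.broadcast(tensor, src=ranks[0], group=group)  # `src` is a global rank
    for i, byte in enumerate(tensor.tolist()):
        unique_id.internal[i] = byte
    return unique_id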

vllm/distributed/device_communicators/pynccl_utils.py (deleted)

@ -1,66 +0,0 @@
import contextlib
from typing import Optional
import torch
from torch.distributed import ProcessGroup, ReduceOp
from vllm.logger import init_logger
logger = init_logger(__name__)
try:
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetVersion)
except Exception as e:
# in non-NVIDIA environments, we can't import the nccl module
# e.g. when running on machines with AMD GPUs
logger.info("Failed to import NCCL library: %s", e)
logger.info("It is expected if you are not running on NVIDIA GPUs.")
pass
comm: Optional["NCCLCommunicator"] = None
def is_initialized() -> bool:
"""Returns whether the NCCL backend is initialized."""
return comm is not None
@contextlib.contextmanager
def set_pynccl_stream(stream: torch.cuda.Stream):
"""Set the cuda stream for communication"""
try:
assert comm is not None
comm.stream = stream
yield
finally:
pass
def init_process_group(group: Optional[ProcessGroup] = None) -> None:
assert not is_initialized()
global comm
logger.info("vLLM is using nccl==%s", ncclGetVersion())
comm = NCCLCommunicator(group=group)
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
"""All-reduces the input tensor across the process group."""
assert input_.is_cuda, f"{input_} should be a cuda tensor"
assert comm is not None
comm.all_reduce(input_, op)
def destroy_process_group() -> None:
global comm
comm = None
def get_world_size() -> int:
"""Returns the world size."""
assert comm is not None
return comm.world_size
def get_nccl_backend() -> Optional["NCCLCommunicator"]:
return comm

vllm/distributed/device_communicators/pynccl_wrapper.py (new file)

@ -0,0 +1,258 @@
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`; it calls NCCL correctly, but `cupy` itself
# often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# invokes many other CUDA APIs that are not allowed while capturing
# a CUDA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related to nccl versions and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
# variable in the code.
import ctypes
import platform
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import torch
from torch.distributed import ReduceOp
from vllm.logger import init_logger
from vllm.utils import find_nccl_library, nccl_integrity_check
logger = init_logger(__name__)
# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
ncclResult_t = ctypes.c_int
ncclComm_t = ctypes.c_void_p
class ncclUniqueId(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
cudaStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p
ncclDataType_t = ctypes.c_int
class ncclDataTypeEnum:
ncclInt8 = 0
ncclChar = 0
ncclUint8 = 1
ncclInt32 = 2
ncclInt = 2
ncclUint32 = 3
ncclInt64 = 4
ncclUint64 = 5
ncclFloat16 = 6
ncclHalf = 6
ncclFloat32 = 7
ncclFloat = 7
ncclFloat64 = 8
ncclDouble = 8
ncclBfloat16 = 9
ncclNumTypes = 10
@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.ncclInt8
if dtype == torch.uint8:
return cls.ncclUint8
if dtype == torch.int32:
return cls.ncclInt32
if dtype == torch.int64:
return cls.ncclInt64
if dtype == torch.float16:
return cls.ncclFloat16
if dtype == torch.float32:
return cls.ncclFloat32
if dtype == torch.float64:
return cls.ncclFloat64
if dtype == torch.bfloat16:
return cls.ncclBfloat16
raise ValueError(f"Unsupported dtype: {dtype}")
ncclRedOp_t = ctypes.c_int
class ncclRedOpTypeEnum:
ncclSum = 0
ncclProd = 1
ncclMax = 2
ncclMin = 3
ncclAvg = 4
ncclNumOps = 5
@classmethod
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.ncclSum
if op == ReduceOp.PRODUCT:
return cls.ncclProd
if op == ReduceOp.MAX:
return cls.ncclMax
if op == ReduceOp.MIN:
return cls.ncclMin
if op == ReduceOp.AVG:
return cls.ncclAvg
raise ValueError(f"Unsupported op: {op}")
@dataclass
class Function:
name: str
restype: Any
argtypes: List[Any]
class NCCLLibrary:
exported_functions = [
# const char* ncclGetErrorString(ncclResult_t result)
Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
# ncclResult_t ncclGetVersion(int *version);
Function("ncclGetVersion", ncclResult_t,
[ctypes.POINTER(ctypes.c_int)]),
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
Function("ncclGetUniqueId", ncclResult_t,
[ctypes.POINTER(ncclUniqueId)]),
# ncclResult_t ncclCommInitRank(
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
# note that ncclComm_t is a pointer type, so the first argument
# is a pointer to a pointer
Function("ncclCommInitRank", ncclResult_t, [
ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
ctypes.c_int
]),
# ncclResult_t ncclAllReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclAllReduce", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ncclComm_t, cudaStream_t
]),
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
# it is better not to call it at all.
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
]
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the corresponding dictionary of ctypes function handles
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
so_file = so_file or find_nccl_library()
try:
# load the library in another process.
# if it core dumps, it will not crash the current process
nccl_integrity_check(so_file)
except Exception as e:
logger.error(
"Failed to load NCCL library from %s ."
"It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise, the nccl library might not exist, be corrupted "
"or it does not support the current platform %s."
"One solution is to download libnccl2 version 2.18 from "
"https://developer.download.nvidia.com/compute/cuda/repos/ "
"and extract the libnccl.so.2 file. If you already have the "
"library, please set the environment variable VLLM_NCCL_SO_PATH"
" to point to the correct nccl library path.", so_file,
platform.platform())
raise e
if so_file not in NCCLLibrary.path_to_dict_mapping:
lib = ctypes.CDLL(so_file)
NCCLLibrary.path_to_library_cache[so_file] = lib
self.lib = NCCLLibrary.path_to_library_cache[so_file]
if so_file not in NCCLLibrary.path_to_dict_mapping:
_funcs = {}
for func in NCCLLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
def ncclGetErrorString(self, result: ncclResult_t) -> str:
return self._funcs["ncclGetErrorString"](result).decode("utf-8")
def NCCL_CHECK(self, result: ncclResult_t) -> None:
if result != 0:
error_str = self.ncclGetErrorString(result)
raise RuntimeError(f"NCCL error: {error_str}")
def ncclGetVersion(self) -> str:
version = ctypes.c_int()
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
version_str = str(version.value)
# something like 21903 --> "2.19.3"
major = version_str[0].lstrip("0")
minor = version_str[1:3].lstrip("0")
patch = version_str[3:].lstrip("0")
return f"{major}.{minor}.{patch}"
def ncclGetUniqueId(self) -> ncclUniqueId:
unique_id = ncclUniqueId()
self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
ctypes.byref(unique_id)))
return unique_id
def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
rank: int) -> ncclComm_t:
comm = ncclComm_t()
self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
world_size, unique_id,
rank))
return comm
def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
# `datatype` should really be `ncclDataType_t` and `op` should be
# `ncclRedOp_t`; both are aliases of `ctypes.c_int`. When we pass a plain
# int to the function, ctypes converts it to `ctypes.c_int` automatically.
self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
datatype, op, comm,
stream))
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
__all__ = [
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
"ncclComm_t", "cudaStream_t", "buffer_type"
]
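
A small, hedged sketch of using the new wrapper directly. Only the non-collective calls are shown, since they are safe in a single process; the explicit path argument is optional, and `find_nccl_library()` also honors the `VLLM_NCCL_SO_PATH` environment variable mentioned in the header comment. Note that the class-level caches above ensure the shared library and its function table are loaded once per path, however many communicators are created on top of it.

from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary

lib = NCCLLibrary()  # or NCCLLibrary("/path/to/libnccl.so.2")
print("loaded nccl version:", lib.ncclGetVersion())   # e.g. "2.19.3"

unique_id = lib.ncclGetUniqueId()
print("unique id (first bytes):", list(unique_id.internal)[:8])

# ncclCommInitRank / ncclAllReduce / ncclCommDestroy are collective calls:
# every rank of the communicator must call them, so they are not exercised
# in this single-process sketch.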

vllm/distributed/parallel_state.py

@ -3,10 +3,10 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
import contextlib
from typing import Optional
from typing import List, Optional
import torch
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.logger import init_logger
@ -14,10 +14,11 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
# Tensor model parallel group that the current rank belongs to.
_TP_DEVICE_GROUP = None
_TP_CPU_GROUP = None
_TP_DEVICE_GROUP: Optional[ProcessGroup] = None
_TP_CPU_GROUP: Optional[ProcessGroup] = None
_TP_PYNCCL_COMMUNICATOR = None
# Pipeline model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
_PP_DEVICE_GROUP: Optional[ProcessGroup] = None
# when people blindly call `torch.distributed.all_reduce` etc,
# it will use this group. It is initialized with the `backend`
@ -41,11 +42,16 @@ _CPU_WORLD_GROUP = None
# A list of global ranks for each pipeline group to ease calculation of the
# source rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None
_PP_GLOBAL_RANKS: Optional[List[int]] = None
_LOCAL_RANK = -1
def get_tp_pynccl_communicator():
global _TP_PYNCCL_COMMUNICATOR
return _TP_PYNCCL_COMMUNICATOR
def get_local_rank():
global _LOCAL_RANK
return _LOCAL_RANK
@ -80,10 +86,20 @@ def init_distributed_environment(
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
if local_rank == -1 and distributed_init_method == "env://":
local_rank = envs.LOCAL_RANK
if local_rank == -1:
# local rank is not set; this usually happens in a single-node
# setting, where we can use the rank as the local rank
if distributed_init_method == "env://":
local_rank = envs.LOCAL_RANK
else:
local_rank = rank
global _LOCAL_RANK
_LOCAL_RANK = local_rank
# A small all_reduce for warmup.
data = torch.zeros(1)
if torch.cuda.is_available():
data = data.to(device=f"cuda:{local_rank}")
torch.distributed.all_reduce(data)
def initialize_model_parallel(
@ -133,29 +149,36 @@ def initialize_model_parallel(
rank = torch.distributed.get_rank()
# Build the tensor model-parallel groups.
global _TP_DEVICE_GROUP, _TP_CPU_GROUP
global _TP_DEVICE_GROUP, _TP_CPU_GROUP, _TP_PYNCCL_COMMUNICATOR
assert _TP_DEVICE_GROUP is None, (
"tensor model parallel group is already initialized")
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size)
ranks = list(
range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size))
group = torch.distributed.new_group(ranks, backend=backend)
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if rank in ranks:
_TP_DEVICE_GROUP = group
_TP_CPU_GROUP = cpu_group
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
_TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
group=_TP_CPU_GROUP,
device=_LOCAL_RANK,
)
# Build the pipeline model-parallel groups.
global _PIPELINE_MODEL_PARALLEL_GROUP
global _PIPELINE_GLOBAL_RANKS
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, (
global _PP_DEVICE_GROUP
global _PP_GLOBAL_RANKS
assert _PP_DEVICE_GROUP is None, (
"pipeline model parallel group is already initialized")
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
group = torch.distributed.new_group(ranks, backend=backend)
if rank in ranks:
_PIPELINE_MODEL_PARALLEL_GROUP = group
_PIPELINE_GLOBAL_RANKS = ranks
_PP_DEVICE_GROUP = group
_PP_GLOBAL_RANKS = ranks
def ensure_model_parallel_initialized(
@ -188,8 +211,7 @@ def ensure_model_parallel_initialized(
def model_parallel_is_initialized():
"""Check if tensor and pipeline parallel groups are initialized."""
return (_TP_DEVICE_GROUP is not None
and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
return (_TP_DEVICE_GROUP is not None and _PP_DEVICE_GROUP is not None)
def get_cpu_world_group():
@ -214,9 +236,9 @@ def get_tensor_model_parallel_cpu_group():
def get_pipeline_model_parallel_group():
"""Get the pipeline model parallel group the caller rank belongs to."""
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, (
assert _PP_DEVICE_GROUP is not None, (
"pipeline model parallel group is not initialized")
return _PIPELINE_MODEL_PARALLEL_GROUP
return _PP_DEVICE_GROUP
def get_tensor_model_parallel_world_size():
@ -253,36 +275,36 @@ def get_tensor_model_parallel_src_rank():
def get_pipeline_model_parallel_first_rank():
"""Return the global rank of the first process in the pipeline for the
current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, (
assert _PP_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized")
return _PIPELINE_GLOBAL_RANKS[0]
return _PP_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_last_rank():
"""Return the global rank of the last process in the pipeline for the
current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, (
assert _PP_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized")
last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
return _PP_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_next_rank():
"""Return the global rank that follows the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, (
assert _PP_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized")
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
return _PP_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
def get_pipeline_model_parallel_prev_rank():
"""Return the global rank that precedes the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, (
assert _PP_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized")
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
return _PP_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
def destroy_model_parallel():
@ -295,45 +317,12 @@ def destroy_model_parallel():
if _TP_CPU_GROUP:
torch.distributed.destroy_process_group(_TP_CPU_GROUP)
_TP_CPU_GROUP = None
global _PIPELINE_MODEL_PARALLEL_GROUP
if _PIPELINE_MODEL_PARALLEL_GROUP:
torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _PIPELINE_GLOBAL_RANKS
_PIPELINE_GLOBAL_RANKS = None
from vllm.distributed.device_communicators import pynccl_utils
global _TP_PYNCCL_COMMUNICATOR
_TP_PYNCCL_COMMUNICATOR = None
# Destroy the pynccl states if any.
pynccl_utils.destroy_process_group()
# Whether to use pynccl for nccl all reduce.
# We use pynccl for all reduce when using CUDA graph, because torch.distributed
# is not well supported by CUDA graph.
_ENABLE_PYNCCL_FOR_ALL_REDUCE = False
@contextlib.contextmanager
def with_pynccl_for_all_reduce():
from vllm.distributed.device_communicators import pynccl_utils
"""use pynccl instead of torch.distributed for all reduce"""
tp_size = get_tensor_model_parallel_world_size()
if tp_size == 1:
# No-op.
# NOTE(woosuk): We don't initialize pynccl when tp_size is 1.
yield
else:
global _ENABLE_PYNCCL_FOR_ALL_REDUCE
old = _ENABLE_PYNCCL_FOR_ALL_REDUCE
_ENABLE_PYNCCL_FOR_ALL_REDUCE = True
stream = torch.cuda.current_stream()
with pynccl_utils.set_pynccl_stream(stream):
yield
_ENABLE_PYNCCL_FOR_ALL_REDUCE = old
def is_pynccl_enabled_for_all_reduce():
"""check if pynccl is enabled for all reduce"""
global _ENABLE_PYNCCL_FOR_ALL_REDUCE
return _ENABLE_PYNCCL_FOR_ALL_REDUCE
global _PP_DEVICE_GROUP
if _PP_DEVICE_GROUP:
torch.distributed.destroy_process_group(_PP_DEVICE_GROUP)
_PP_DEVICE_GROUP = None
global _PP_GLOBAL_RANKS
_PP_GLOBAL_RANKS = None
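
The removed module-level switch (`with_pynccl_for_all_reduce` / `is_pynccl_enabled_for_all_reduce`) is replaced by per-communicator state: callers fetch the tensor-parallel communicator and check its `disabled` flag, which is toggled through `change_state` (usually via `graph_capture_mode`). A hedged sketch of the new query pattern, assuming `ensure_model_parallel_initialized(...)` has already run; the helper name is hypothetical.

import torch

from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
                                             get_tp_pynccl_communicator)


def tp_all_reduce_inplace(tensor: torch.Tensor) -> torch.Tensor:
    pynccl_comm = get_tp_pynccl_communicator()
    if pynccl_comm is not None and not pynccl_comm.disabled:
        # only enabled inside `change_state(enable=True)`,
        # e.g. under graph_capture_mode()
        pynccl_comm.all_reduce(tensor)
    else:
        torch.distributed.all_reduce(tensor,
                                     group=get_tensor_model_parallel_group())
    return tensor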

vllm/worker/model_runner.py

@ -1,4 +1,3 @@
import contextlib
import time
from enum import IntEnum
from typing import Dict, List, NamedTuple, Optional, Set, Tuple
@ -12,9 +11,9 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
from vllm.distributed.device_communicators import (custom_all_reduce,
pynccl_utils)
from vllm.distributed import broadcast_tensor_dict
from vllm.distributed.communication_op import graph_capture_mode
from vllm.distributed.device_communicators import custom_all_reduce
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
@ -917,10 +916,6 @@ class ModelRunner:
Since it is used for decoding-only, it assumes there's only 1 token
per sequence in the batch.
"""
# NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
# deleted before the CUDA graphs.
self.pynccl_backend = pynccl_utils.get_nccl_backend()
assert not self.model_config.enforce_eager
logger.info("Capturing the model for CUDA graphs. This may lead to "
"unexpected consequences if the model is not static. To "
@ -1046,7 +1041,7 @@ class CUDAGraphRunner:
# Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
with _maybe_pynccl():
with graph_capture_mode():
self.model(
input_ids,
positions,
@ -1061,7 +1056,7 @@ class CUDAGraphRunner:
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
self._graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117
with _maybe_pynccl():
with graph_capture_mode():
hidden_states = self.model(
input_ids,
positions,
@ -1113,16 +1108,6 @@ class CUDAGraphRunner:
return self.forward(*args, **kwargs)
@contextlib.contextmanager
def _maybe_pynccl():
if pynccl_utils.is_initialized(
) and not custom_all_reduce.is_initialized():
with with_pynccl_for_all_reduce():
yield
else:
yield
def _get_graph_batch_size(batch_size: int) -> int:
"""Returns the padded batch size given actual batch size.

vllm/worker/worker.py

@ -11,9 +11,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
get_tensor_model_parallel_cpu_group,
init_distributed_environment)
from vllm.distributed.device_communicators import pynccl_utils
from vllm.distributed.device_communicators.custom_all_reduce import (
init_custom_ar)
from vllm.lora.request import LoRARequest
@ -306,29 +304,10 @@ def init_worker_distributed_environment(
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
if pynccl_utils.is_initialized():
pynccl_world_size = pynccl_utils.get_world_size()
if pynccl_world_size != parallel_config.world_size:
raise RuntimeError(
"pynccl is already initialized but the pynccl world "
"size does not match parallel_config.world_size "
f"({pynccl_world_size} vs. {parallel_config.world_size}).")
elif parallel_config.world_size > 1:
# NOTE(woosuk): We don't initialize pynccl process group when world size
# is 1.
# NOTE(kaichao): By default, pynccl is initialized for tp group.
pynccl_utils.init_process_group(
group=get_tensor_model_parallel_cpu_group())
# Initialize a custom fast all-reduce implementation.
if not parallel_config.disable_custom_all_reduce:
init_custom_ar()
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
if pynccl_utils.is_initialized():
pynccl_utils.all_reduce(torch.zeros(1).cuda())
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.