[Core][Distributed] refactor pynccl to hold multiple communicators (#4591)

parent c833101740
commit 208b71bcc1

tests/distributed/test_pynccl.py
@@ -1,15 +1,15 @@
 import multiprocessing
+import os
 
 import pytest
 import torch
 
-import vllm.distributed.device_communicators.pynccl_utils as pynccl_utils
-from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
-from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
-                                                          ncclGetUniqueId)
-from vllm.distributed.parallel_state import (
-    ensure_model_parallel_initialized, get_tensor_model_parallel_cpu_group,
-    init_distributed_environment, with_pynccl_for_all_reduce)
+from vllm.distributed.communication_op import (  # noqa
+    graph_capture_mode, tensor_model_parallel_all_reduce)
+from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
 from vllm.utils import update_environment_variables
 
 
@@ -41,6 +41,9 @@ def worker_fn_wrapper(fn):
     # and update the environment variables in the function
     def wrapped_fn(env):
         update_environment_variables(env)
+        local_rank = os.environ['LOCAL_RANK']
+        device = torch.device(f"cuda:{local_rank}")
+        torch.cuda.set_device(device)
         init_distributed_environment()
         fn()
 
@@ -49,11 +52,13 @@ def worker_fn_wrapper(fn):
 
 @worker_fn_wrapper
 def worker_fn():
-    comm = NCCLCommunicator()
-    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
-    comm.all_reduce(tensor)
+    pynccl_comm = PyNcclCommunicator()
+    tensor = torch.ones(16, 1024, 1024,
+                        dtype=torch.float32).cuda(pynccl_comm.rank)
+    with pynccl_comm.change_state(enable=True):
+        pynccl_comm.all_reduce(tensor)
     result = tensor.mean().cpu().item()
-    assert result == comm.world_size
+    assert result == pynccl_comm.world_size
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -70,16 +75,17 @@ def multiple_tp_worker_fn():
         torch.distributed.new_group(ranks=[2, 3], backend="gloo")
     ]
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
-    comm = NCCLCommunicator(group=group, device=device)
+    pynccl_comm = PyNcclCommunicator(group=group, device=device)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    # two groups can communicate independently
-    if torch.distributed.get_rank() in [0, 1]:
-        comm.all_reduce(tensor)
-        comm.all_reduce(tensor)
-        result = tensor.mean().cpu().item()
-        assert result == 4
-    else:
-        comm.all_reduce(tensor)
-        result = tensor.mean().cpu().item()
-        assert result == 2
+    with pynccl_comm.change_state(enable=True):
+        # two groups can communicate independently
+        if torch.distributed.get_rank() in [0, 1]:
+            pynccl_comm.all_reduce(tensor)
+            pynccl_comm.all_reduce(tensor)
+            result = tensor.mean().cpu().item()
+            assert result == 4
+        else:
+            pynccl_comm.all_reduce(tensor)
+            result = tensor.mean().cpu().item()
+            assert result == 2
 
@@ -88,19 +94,16 @@ def multiple_tp_worker_fn():
                     reason="Need at least 4 GPUs to run the test.")
 def test_pynccl_multiple_tp():
     # this tests pynccl for multiple tp groups, in a standalone way
-    # i.e. call `comm.all_reduce` directly
+    # i.e. call `pynccl_comm.all_reduce` directly
     distributed_run(multiple_tp_worker_fn, 4)
 
 
 @worker_fn_wrapper
 def multiple_tp_with_vllm_worker_fn():
     device = torch.device(f"cuda:{torch.distributed.get_rank()}")
-    torch.cuda.set_device(torch.distributed.get_rank())
     ensure_model_parallel_initialized(2, 2)
-    pynccl_utils.init_process_group(
-        group=get_tensor_model_parallel_cpu_group())
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with with_pynccl_for_all_reduce():
+    with graph_capture_mode():
         # two tp groups can communicate independently
         if torch.distributed.get_rank() in [0, 1]:
             tensor = tensor_model_parallel_all_reduce(tensor)
@@ -125,19 +128,21 @@ def test_pynccl_multiple_tp_with_vllm():
 def worker_fn_with_cudagraph():
     with torch.no_grad():
         graph = torch.cuda.CUDAGraph()
-        comm = NCCLCommunicator()
+        pynccl_comm = PyNcclCommunicator()
         # run something in the default stream to initialize torch engine
-        a = torch.ones((4, 4), device=f'cuda:{comm.rank}')
+        a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
-        with torch.cuda.graph(graph, stream=comm.stream):
+        with torch.cuda.graph(
+                graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
+                    enable=True):
             # operation during the graph capture is recorded but not executed
             # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
-            comm.all_reduce(a)
-        comm.stream.synchronize()
-        assert a.mean().cpu().item() == comm.world_size**0
+            pynccl_comm.all_reduce(a)
+        pynccl_comm.stream.synchronize()
+        assert a.mean().cpu().item() == pynccl_comm.world_size**0
         graph.replay()
-        comm.stream.synchronize()
-        assert a.mean().cpu().item() == comm.world_size**1
+        pynccl_comm.stream.synchronize()
+        assert a.mean().cpu().item() == pynccl_comm.world_size**1
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -147,7 +152,8 @@ def test_pynccl_with_cudagraph():
 
 
 def test_ncclGetUniqueId():
-    unique_id = ncclGetUniqueId()
+    lib = NCCLLibrary()
+    unique_id = lib.ncclGetUniqueId()
     # `list(unique_id.internal)` is something like this:
     # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
     # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

vllm/distributed/communication_op.py
@@ -1,4 +1,5 @@
 from collections import namedtuple
+from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -8,7 +9,26 @@ from .parallel_state import (get_cpu_world_group,
                              get_tensor_model_parallel_group,
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
-                             is_pynccl_enabled_for_all_reduce)
+                             get_tp_pynccl_communicator)
+
+
+@contextmanager
+def graph_capture_mode():
+    # In graph capture, we have to be very careful about the collective
+    # operations. The current status is:
+    # allreduce \ Mode   |  Eager  |  Graph  |
+    # --------------------------------------------
+    # custom allreduce   | enabled | enabled |
+    # PyNccl             | disabled| enabled |
+    # torch.distributed  | enabled | disabled|
+    #
+    # Note that custom allreduce will have a runtime check, if the tensor size
+    # is too large, it will fallback to the next available option.
+    pynccl_comm = get_tp_pynccl_communicator()
+    assert pynccl_comm is not None
+    with pynccl_comm.change_state(enable=True,
+                                  stream=torch.cuda.current_stream()):
+        yield
 
 
 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@@ -23,7 +43,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     TLDR: always assume this function modifies its input, but use the return
     value as the output.
     """
-    from vllm.distributed.device_communicators import pynccl_utils
     from vllm.distributed.device_communicators.custom_all_reduce import (
         custom_all_reduce)
 
@@ -33,8 +52,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     out = custom_all_reduce(input_)
     if out is not None:
         return out
-    if is_pynccl_enabled_for_all_reduce():
-        pynccl_utils.all_reduce(input_)
+    pynccl_comm = get_tp_pynccl_communicator()
+    if (pynccl_comm is not None and not pynccl_comm.disabled):
+        pynccl_comm.all_reduce(input_)
     else:
         torch.distributed.all_reduce(input_,
                                      group=get_tensor_model_parallel_group())
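
For orientation, here is a hypothetical usage sketch of the new graph_capture_mode context manager (it is not part of this commit's diff). It assumes torch.distributed and vLLM's model-parallel groups are already initialized with a tensor-parallel size greater than one, as in the updated test file; the function name is illustrative.

# Hedged sketch, not from this commit: route an all-reduce through the TP
# PyNcclCommunicator by entering graph_capture_mode.
import torch

from vllm.distributed.communication_op import (graph_capture_mode,
                                               tensor_model_parallel_all_reduce)


def allreduce_with_pynccl_enabled():
    # assumes init_distributed_environment() and
    # ensure_model_parallel_initialized(tp_size, pp_size) already ran
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device="cuda")
    with graph_capture_mode():
        # inside the context the communicator is enabled, so the all-reduce
        # goes through pynccl instead of torch.distributed
        tensor = tensor_model_parallel_all_reduce(tensor)
    return tensor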

vllm/distributed/device_communicators/pynccl.py
@@ -1,26 +1,4 @@
-# This file is a pure Python wrapper for the NCCL library.
-# The main purpose is to use NCCL combined with CUDA graph.
-# Before writing this script, we tried the following approach:
-# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
-#    often gets stuck when initializing the NCCL communicator.
-# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
-#    contains many other potential cuda APIs, that are not allowed during
-#    capturing the CUDA graph. For further details, please check
-# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
-#
-# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
-# doable, but we often encounter issues related with nccl versions, and need
-# to switch between different versions of NCCL. See
-# https://github.com/NVIDIA/nccl/issues/1234 for more details.
-# A C/C++ binding is not flexible enough to handle this. It requires
-# recompilation of the code every time we want to switch between different
-# versions. This current implementation, with a **pure** Python wrapper, is
-# more flexible. We can easily switch between different versions of NCCL by
-# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
-# variable in the code.
-
-import ctypes
-import platform
+from contextlib import contextmanager
 from typing import Optional, Union
 
 # ===================== import region =====================
@@ -28,217 +6,70 @@ import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup, ReduceOp
 
+from vllm.distributed.device_communicators.pynccl_wrapper import (
+    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
+    ncclRedOpTypeEnum, ncclUniqueId)
 from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
 from vllm.logger import init_logger
-from vllm.utils import find_nccl_library, nccl_integrity_check
 
 logger = init_logger(__name__)
 
-so_file = find_nccl_library()
-
-try:
-    # load the library in another process.
-    # if it core dumps, it will not crash the current process
-    nccl_integrity_check(so_file)
-    nccl = ctypes.CDLL(so_file)
-except Exception as e:
-    logger.error(
-        "Failed to load NCCL library from %s ."
-        "It is expected if you are not running on NVIDIA/AMD GPUs."
-        "Otherwise, the nccl library might not exist, be corrupted "
-        "or it does not support the current platform %s."
-        "One solution is to download libnccl2 version 2.18 from "
-        "https://developer.download.nvidia.com/compute/cuda/repos/ "
-        "and extract the libnccl.so.2 file. If you already have the "
-        "library, please set the environment variable VLLM_NCCL_SO_PATH"
-        " to point to the correct nccl library path.", so_file,
-        platform.platform())
-    raise e
-
-# === export types and functions from nccl to Python ===
-# for the original nccl definition, please check
-# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
-
-ncclResult_t = ctypes.c_int
-
-_c_ncclGetErrorString = nccl.ncclGetErrorString
-_c_ncclGetErrorString.restype = ctypes.c_char_p
-_c_ncclGetErrorString.argtypes = [ncclResult_t]
-
-
-def NCCL_CHECK(result: ncclResult_t) -> None:
-    if result != 0:
-        error_str = _c_ncclGetErrorString(result)
-        error_str = error_str.decode("utf-8")
-        raise RuntimeError(f"NCCL error: {error_str}")
-
-
-# equivalent to c declaration:
-# ncclResult_t ncclGetVersion(int *version);
-_c_ncclGetVersion = nccl.ncclGetVersion
-_c_ncclGetVersion.restype = ctypes.c_int
-_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
-
-
-def ncclGetVersion() -> str:
-    version = ctypes.c_int()
-    NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
-    # something like 21903 --> "2.19.3"
-    version_str = str(version.value)
-    major = version_str[0].lstrip("0")
-    minor = version_str[1:3].lstrip("0")
-    patch = version_str[3:].lstrip("0")
-    return f"{major}.{minor}.{patch}"
-
-
-class NcclUniqueId(ctypes.Structure):
-    _fields_ = [("internal", ctypes.c_byte * 128)]
-
-
-# equivalent to c declaration:
-# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
-_c_ncclGetUniqueId = nccl.ncclGetUniqueId
-_c_ncclGetUniqueId.restype = ctypes.c_int
-_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
-
-
-def ncclGetUniqueId() -> NcclUniqueId:
-    unique_id = NcclUniqueId()
-    NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
-    return unique_id
-
-
-# equivalent to c declaration:
-# ncclResult_t  ncclCommInitRank(
-#   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
-# note that ncclComm_t is a pointer type, so the first argument
-# is a pointer to a pointer
-_c_ncclCommInitRank = nccl.ncclCommInitRank
-_c_ncclCommInitRank.restype = ctypes.c_int
-_c_ncclCommInitRank.argtypes = [
-    ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
-]
-
-ncclDataType_t = ctypes.c_int
-
-
-class ncclDataTypeEnum:
-    ncclInt8 = 0
-    ncclChar = 0
-    ncclUint8 = 1
-    ncclInt32 = 2
-    ncclInt = 2
-    ncclUint32 = 3
-    ncclInt64 = 4
-    ncclUint64 = 5
-    ncclFloat16 = 6
-    ncclHalf = 6
-    ncclFloat32 = 7
-    ncclFloat = 7
-    ncclFloat64 = 8
-    ncclDouble = 8
-    ncclBfloat16 = 9
-    ncclNumTypes = 10
-
-    @classmethod
-    def from_torch(cls, dtype: torch.dtype) -> int:
-        if dtype == torch.int8:
-            return cls.ncclInt8
-        if dtype == torch.uint8:
-            return cls.ncclUint8
-        if dtype == torch.int32:
-            return cls.ncclInt32
-        if dtype == torch.int64:
-            return cls.ncclInt64
-        if dtype == torch.float16:
-            return cls.ncclFloat16
-        if dtype == torch.float32:
-            return cls.ncclFloat32
-        if dtype == torch.float64:
-            return cls.ncclFloat64
-        if dtype == torch.bfloat16:
-            return cls.ncclBfloat16
-        raise ValueError(f"Unsupported dtype: {dtype}")
-
-
-ncclRedOp_t = ctypes.c_int
-
-
-class ncclRedOpTypeEnum:
-    ncclSum = 0
-    ncclProd = 1
-    ncclMax = 2
-    ncclMin = 3
-    ncclAvg = 4
-    ncclNumOps = 5
-
-    @classmethod
-    def from_torch(cls, op: ReduceOp) -> int:
-        if op == ReduceOp.SUM:
-            return cls.ncclSum
-        if op == ReduceOp.PRODUCT:
-            return cls.ncclProd
-        if op == ReduceOp.MAX:
-            return cls.ncclMax
-        if op == ReduceOp.MIN:
-            return cls.ncclMin
-        if op == ReduceOp.AVG:
-            return cls.ncclAvg
-        raise ValueError(f"Unsupported op: {op}")
-
-
-# equivalent to c declaration:
-# ncclResult_t  ncclAllReduce(
-#   const void* sendbuff, void* recvbuff, size_t count,
-#   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
-#   udaStream_t stream);
-# note that cudaStream_t is a pointer type, so the last argument is a pointer
-_c_ncclAllReduce = nccl.ncclAllReduce
-_c_ncclAllReduce.restype = ctypes.c_int
-_c_ncclAllReduce.argtypes = [
-    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t,
-    ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
-]
-
-# be cautious! this is a collective call, it will block until all
-# processes in the communicator have called this function.
-# because Python object destruction can happen in random order,
-# it is better not to call it at all.
-# equivalent to c declaration:
-# ncclResult_t  ncclCommDestroy(ncclComm_t comm);
-_c_ncclCommDestroy = nccl.ncclCommDestroy
-_c_ncclCommDestroy.restype = ctypes.c_int
-_c_ncclCommDestroy.argtypes = [ctypes.c_void_p]
-
-
-class NCCLCommunicator:
+
+class PyNcclCommunicator:
 
     def __init__(
         self,
         group: Optional[ProcessGroup] = None,
         device: Optional[Union[int, str, torch.device]] = None,
+        library_path: Optional[str] = None,
     ):
         """
         Args:
             group: the process group to work on. If None, it will use the
                 default process group.
-            device: the device to bind the NCCLCommunicator to. If None,
+            device: the device to bind the PyNcclCommunicator to. If None,
                 it will be bind to f"cuda:{local_rank}".
+            library_path: the path to the NCCL library. If None, it will
+                use the default library path.
         It is the caller's responsibility to make sure each communicator
         is bind to a unique device.
         """
         assert dist.is_initialized()
         group = get_cpu_world_group() if group is None else group
         assert dist.get_backend(group) != dist.Backend.NCCL, (
-            "NCCLCommunicator should be attached to a non-NCCL group.")
+            "PyNcclCommunicator should be attached to a non-NCCL group.")
         self.group = group
         # note: this rank is the rank in the group
         self.rank = dist.get_rank(group)
         self.world_size = dist.get_world_size(group)
 
+        # if world_size == 1, no need to create communicator
+        if self.world_size == 1:
+            self.available = False
+            self.disabled = True
+            self.stream = None
+            return
+        try:
+            self.nccl = NCCLLibrary(library_path)
+        except Exception:
+            # disable because of missing NCCL library
+            # e.g. in a non-GPU environment
+            self.available = False
+            self.disabled = True
+            self.stream = None
+            return
+
+        self.available = True
+        self.disabled = False
+
+        logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
+
         if self.rank == 0:
-            self.unique_id = ncclGetUniqueId()
+            # get the unique id from NCCL
+            self.unique_id = self.nccl.ncclGetUniqueId()
         else:
-            self.unique_id = NcclUniqueId()
+            # construct an empty unique id
+            self.unique_id = ncclUniqueId()
         tensor = torch.ByteTensor(list(self.unique_id.internal))
         ranks = dist.get_process_group_ranks(group)
         # arg `src` in `broadcast` is the global rank
@@ -246,7 +77,6 @@ class NCCLCommunicator:
         byte_list = tensor.tolist()
         for i, byte in enumerate(byte_list):
             self.unique_id.internal[i] = byte
-        self.comm = ctypes.c_void_p()
         if device is None:
             local_rank = get_local_rank()
             device = torch.device(f"cuda:{local_rank}")
@@ -261,15 +91,25 @@ class NCCLCommunicator:
         # `torch.cuda.device` is a context manager that changes the
         # current cuda device to the specified one
         with torch.cuda.device(device):
-            NCCL_CHECK(
-                _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
-                                    self.unique_id, self.rank))
+            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
+                self.world_size, self.unique_id, self.rank)
             self.stream = torch.cuda.Stream()
 
+            # A small all_reduce for warmup.
+            self.all_reduce(torch.zeros(1, device=device))
+            self.stream.synchronize()
+
+        # by default it is disabled, e.g. in profiling models and prefill phase.
+        # to use it, use under `with obj.change_state(enable=True)`, usually
+        # when we are using CUDA graph.
+        self.disabled = True
+
     def all_reduce(self,
                    tensor: torch.Tensor,
                    op: ReduceOp = ReduceOp.SUM,
                    stream=None):
+        if self.disabled:
+            return
         # nccl communicator created on a specific device
         # will only work on tensors on the same device
         # otherwise it will cause "illegal memory access"
@@ -278,10 +118,32 @@ class NCCLCommunicator:
             f"but the input tensor is on {tensor.device}")
         if stream is None:
             stream = self.stream
-        NCCL_CHECK(
-            _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
-                             ctypes.c_void_p(tensor.data_ptr()),
-                             tensor.numel(),
-                             ncclDataTypeEnum.from_torch(tensor.dtype),
-                             ncclRedOpTypeEnum.from_torch(op), self.comm,
-                             ctypes.c_void_p(stream.cuda_stream)))
+        self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
+                                buffer_type(tensor.data_ptr()), tensor.numel(),
+                                ncclDataTypeEnum.from_torch(tensor.dtype),
+                                ncclRedOpTypeEnum.from_torch(op), self.comm,
+                                cudaStream_t(stream.cuda_stream))
+
+    @contextmanager
+    def change_state(self,
+                     enable: Optional[bool] = None,
+                     stream: Optional[torch.cuda.Stream] = None):
+        """
+        A context manager to change the state of the communicator.
+        """
+        if enable is None:
+            # guess a default value when not specified
+            enable = self.available
+
+        if stream is None:
+            stream = self.stream
+
+        old_disable = self.disabled
+        old_stream = self.stream
+
+        self.stream = stream
+        self.disabled = not enable
+        yield
+
+        self.disabled = old_disable
+        self.stream = old_stream
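
A hypothetical standalone sketch of the refactored communicator (mirroring worker_fn in the test file above; not part of the diff). The function name is illustrative, and it assumes the process was launched with RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT and LOCAL_RANK set so that init_distributed_environment() can pick them up.

# Hedged sketch, not from this commit: the communicator starts disabled and
# only performs the collective while change_state(enable=True) is active.
import torch

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.parallel_state import init_distributed_environment


def standalone_all_reduce():
    init_distributed_environment()  # world built from the env variables
    pynccl_comm = PyNcclCommunicator()
    tensor = torch.ones(16, 1024, 1024,
                        dtype=torch.float32).cuda(pynccl_comm.rank)
    pynccl_comm.all_reduce(tensor)      # no-op: disabled by default
    with pynccl_comm.change_state(enable=True):
        pynccl_comm.all_reduce(tensor)  # actually calls ncclAllReduce
    pynccl_comm.stream.synchronize()
    assert tensor.mean().cpu().item() == pynccl_comm.world_size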

vllm/distributed/device_communicators/pynccl_utils.py (deleted)
@@ -1,66 +0,0 @@
-import contextlib
-from typing import Optional
-
-import torch
-from torch.distributed import ProcessGroup, ReduceOp
-
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-
-try:
-    from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
-                                                               ncclGetVersion)
-except Exception as e:
-    # in non-NVIDIA environments, we can't import the nccl module
-    # e.g. when running on machines with AMD GPUs
-    logger.info("Failed to import NCCL library: %s", e)
-    logger.info("It is expected if you are not running on NVIDIA GPUs.")
-    pass
-
-comm: Optional["NCCLCommunicator"] = None
-
-
-def is_initialized() -> bool:
-    """Returns whether the NCCL backend is initialized."""
-    return comm is not None
-
-
-@contextlib.contextmanager
-def set_pynccl_stream(stream: torch.cuda.Stream):
-    """Set the cuda stream for communication"""
-    try:
-        assert comm is not None
-        comm.stream = stream
-        yield
-    finally:
-        pass
-
-
-def init_process_group(group: Optional[ProcessGroup] = None) -> None:
-    assert not is_initialized()
-    global comm
-    logger.info("vLLM is using nccl==%s", ncclGetVersion())
-    comm = NCCLCommunicator(group=group)
-
-
-def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
-    """All-reduces the input tensor across the process group."""
-    assert input_.is_cuda, f"{input_} should be a cuda tensor"
-    assert comm is not None
-    comm.all_reduce(input_, op)
-
-
-def destroy_process_group() -> None:
-    global comm
-    comm = None
-
-
-def get_world_size() -> int:
-    """Returns the world size."""
-    assert comm is not None
-    return comm.world_size
-
-
-def get_nccl_backend() -> Optional["NCCLCommunicator"]:
-    return comm

vllm/distributed/device_communicators/pynccl_wrapper.py (new file, 258 lines)
@@ -0,0 +1,258 @@
+# This file is a pure Python wrapper for the NCCL library.
+# The main purpose is to use NCCL combined with CUDA graph.
+# Before writing this script, we tried the following approach:
+# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
+#    often gets stuck when initializing the NCCL communicator.
+# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
+#    contains many other potential cuda APIs, that are not allowed during
+#    capturing the CUDA graph. For further details, please check
+# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
+#
+# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
+# doable, but we often encounter issues related with nccl versions, and need
+# to switch between different versions of NCCL. See
+# https://github.com/NVIDIA/nccl/issues/1234 for more details.
+# A C/C++ binding is not flexible enough to handle this. It requires
+# recompilation of the code every time we want to switch between different
+# versions. This current implementation, with a **pure** Python wrapper, is
+# more flexible. We can easily switch between different versions of NCCL by
+# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
+# variable in the code.
+
+import ctypes
+import platform
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.distributed import ReduceOp
+
+from vllm.logger import init_logger
+from vllm.utils import find_nccl_library, nccl_integrity_check
+
+logger = init_logger(__name__)
+
+# === export types and functions from nccl to Python ===
+# for the original nccl definition, please check
+# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
+
+ncclResult_t = ctypes.c_int
+ncclComm_t = ctypes.c_void_p
+
+
+class ncclUniqueId(ctypes.Structure):
+    _fields_ = [("internal", ctypes.c_byte * 128)]
+
+
+cudaStream_t = ctypes.c_void_p
+buffer_type = ctypes.c_void_p
+
+ncclDataType_t = ctypes.c_int
+
+
+class ncclDataTypeEnum:
+    ncclInt8 = 0
+    ncclChar = 0
+    ncclUint8 = 1
+    ncclInt32 = 2
+    ncclInt = 2
+    ncclUint32 = 3
+    ncclInt64 = 4
+    ncclUint64 = 5
+    ncclFloat16 = 6
+    ncclHalf = 6
+    ncclFloat32 = 7
+    ncclFloat = 7
+    ncclFloat64 = 8
+    ncclDouble = 8
+    ncclBfloat16 = 9
+    ncclNumTypes = 10
+
+    @classmethod
+    def from_torch(cls, dtype: torch.dtype) -> int:
+        if dtype == torch.int8:
+            return cls.ncclInt8
+        if dtype == torch.uint8:
+            return cls.ncclUint8
+        if dtype == torch.int32:
+            return cls.ncclInt32
+        if dtype == torch.int64:
+            return cls.ncclInt64
+        if dtype == torch.float16:
+            return cls.ncclFloat16
+        if dtype == torch.float32:
+            return cls.ncclFloat32
+        if dtype == torch.float64:
+            return cls.ncclFloat64
+        if dtype == torch.bfloat16:
+            return cls.ncclBfloat16
+        raise ValueError(f"Unsupported dtype: {dtype}")
+
+
+ncclRedOp_t = ctypes.c_int
+
+
+class ncclRedOpTypeEnum:
+    ncclSum = 0
+    ncclProd = 1
+    ncclMax = 2
+    ncclMin = 3
+    ncclAvg = 4
+    ncclNumOps = 5
+
+    @classmethod
+    def from_torch(cls, op: ReduceOp) -> int:
+        if op == ReduceOp.SUM:
+            return cls.ncclSum
+        if op == ReduceOp.PRODUCT:
+            return cls.ncclProd
+        if op == ReduceOp.MAX:
+            return cls.ncclMax
+        if op == ReduceOp.MIN:
+            return cls.ncclMin
+        if op == ReduceOp.AVG:
+            return cls.ncclAvg
+        raise ValueError(f"Unsupported op: {op}")
+
+
+@dataclass
+class Function:
+    name: str
+    restype: Any
+    argtypes: List[Any]
+
+
+class NCCLLibrary:
+    exported_functions = [
+        # const char* ncclGetErrorString(ncclResult_t result)
+        Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
+        # ncclResult_t  ncclGetVersion(int *version);
+        Function("ncclGetVersion", ncclResult_t,
+                 [ctypes.POINTER(ctypes.c_int)]),
+        # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
+        Function("ncclGetUniqueId", ncclResult_t,
+                 [ctypes.POINTER(ncclUniqueId)]),
+        # ncclResult_t  ncclCommInitRank(
+        #   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
+        # note that ncclComm_t is a pointer type, so the first argument
+        # is a pointer to a pointer
+        Function("ncclCommInitRank", ncclResult_t, [
+            ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
+            ctypes.c_int
+        ]),
+        # ncclResult_t  ncclAllReduce(
+        #   const void* sendbuff, void* recvbuff, size_t count,
+        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
+        #   cudaStream_t stream);
+        # note that cudaStream_t is a pointer type, so the last argument
+        # is a pointer
+        Function("ncclAllReduce", ncclResult_t, [
+            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
+            ncclRedOp_t, ncclComm_t, cudaStream_t
+        ]),
+
+        # be cautious! this is a collective call, it will block until all
+        # processes in the communicator have called this function.
+        # because Python object destruction can happen in random order,
+        # it is better not to call it at all.
+        # ncclResult_t  ncclCommDestroy(ncclComm_t comm);
+        Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
+    ]
+
+    # class attribute to store the mapping from the path to the library
+    # to avoid loading the same library multiple times
+    path_to_library_cache: Dict[str, Any] = {}
+
+    # class attribute to store the mapping from library path
+    # to the corresponding dictionary
+    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
+
+    def __init__(self, so_file: Optional[str] = None):
+
+        so_file = so_file or find_nccl_library()
+
+        try:
+            # load the library in another process.
+            # if it core dumps, it will not crash the current process
+            nccl_integrity_check(so_file)
+        except Exception as e:
+            logger.error(
+                "Failed to load NCCL library from %s ."
+                "It is expected if you are not running on NVIDIA/AMD GPUs."
+                "Otherwise, the nccl library might not exist, be corrupted "
+                "or it does not support the current platform %s."
+                "One solution is to download libnccl2 version 2.18 from "
+                "https://developer.download.nvidia.com/compute/cuda/repos/ "
+                "and extract the libnccl.so.2 file. If you already have the "
+                "library, please set the environment variable VLLM_NCCL_SO_PATH"
+                " to point to the correct nccl library path.", so_file,
+                platform.platform())
+            raise e
+
+        if so_file not in NCCLLibrary.path_to_dict_mapping:
+            lib = ctypes.CDLL(so_file)
+            NCCLLibrary.path_to_library_cache[so_file] = lib
+        self.lib = NCCLLibrary.path_to_library_cache[so_file]
+
+        if so_file not in NCCLLibrary.path_to_dict_mapping:
+            _funcs = {}
+            for func in NCCLLibrary.exported_functions:
+                f = getattr(self.lib, func.name)
+                f.restype = func.restype
+                f.argtypes = func.argtypes
+                _funcs[func.name] = f
+            NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
+        self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
+
+    def ncclGetErrorString(self, result: ncclResult_t) -> str:
+        return self._funcs["ncclGetErrorString"](result).decode("utf-8")
+
+    def NCCL_CHECK(self, result: ncclResult_t) -> None:
+        if result != 0:
+            error_str = self.ncclGetErrorString(result)
+            raise RuntimeError(f"NCCL error: {error_str}")
+
+    def ncclGetVersion(self) -> str:
+        version = ctypes.c_int()
+        self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
+        version_str = str(version.value)
+        # something like 21903 --> "2.19.3"
+        major = version_str[0].lstrip("0")
+        minor = version_str[1:3].lstrip("0")
+        patch = version_str[3:].lstrip("0")
+        return f"{major}.{minor}.{patch}"
+
+    def ncclGetUniqueId(self) -> ncclUniqueId:
+        unique_id = ncclUniqueId()
+        self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
+            ctypes.byref(unique_id)))
+        return unique_id
+
+    def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
+                         rank: int) -> ncclComm_t:
+        comm = ncclComm_t()
+        self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
+                                                        world_size, unique_id,
+                                                        rank))
+        return comm
+
+    def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
+                      count: int, datatype: int, op: int, comm: ncclComm_t,
+                      stream: cudaStream_t) -> None:
+        # `datatype` actually should be `ncclDataType_t`
+        # and `op` should be `ncclRedOp_t`
+        # both are aliases of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
+                                                     datatype, op, comm,
+                                                     stream))
+
+    def ncclCommDestroy(self, comm: ncclComm_t) -> None:
+        self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
+
+
+__all__ = [
+    "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
+    "ncclComm_t", "cudaStream_t", "buffer_type"
+]
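
The wrapper can also be exercised on its own, without any torch.distributed setup. Below is a hypothetical sketch (not part of the diff; the function name and printouts are illustrative):

# Hedged sketch, not from this commit: inspect the NCCL installation directly.
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary


def inspect_nccl(so_file=None):
    # so_file is optional; by default find_nccl_library() / VLLM_NCCL_SO_PATH
    # decide which libnccl.so.2 gets loaded.
    lib = NCCLLibrary(so_file)
    print("NCCL version:", lib.ncclGetVersion())  # e.g. "2.19.3"
    unique_id = lib.ncclGetUniqueId()
    # 128 opaque bytes that rank 0 broadcasts to the other ranks
    print("unique id length:", len(list(unique_id.internal)))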

vllm/distributed/parallel_state.py
@@ -3,10 +3,10 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
-import contextlib
-from typing import Optional
+from typing import List, Optional
 
 import torch
+from torch.distributed import ProcessGroup
 
 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -14,10 +14,11 @@ from vllm.logger import init_logger
 logger = init_logger(__name__)
 
 # Tensor model parallel group that the current rank belongs to.
-_TP_DEVICE_GROUP = None
-_TP_CPU_GROUP = None
+_TP_DEVICE_GROUP: Optional[ProcessGroup] = None
+_TP_CPU_GROUP: Optional[ProcessGroup] = None
+_TP_PYNCCL_COMMUNICATOR = None
 # Pipeline model parallel group that the current rank belongs to.
-_PIPELINE_MODEL_PARALLEL_GROUP = None
+_PP_DEVICE_GROUP: Optional[ProcessGroup] = None
 
 # when people blindly call `torch.distributed.all_reduce` etc,
 # it will use this group. It is initialized with the `backend`
@@ -41,11 +42,16 @@ _CPU_WORLD_GROUP = None
 
 # A list of global ranks for each pipeline group to ease calculation of the
 # source rank when broadcasting from the first or last pipeline stage.
-_PIPELINE_GLOBAL_RANKS = None
+_PP_GLOBAL_RANKS: Optional[List[int]] = None
 
 _LOCAL_RANK = -1
 
 
+def get_tp_pynccl_communicator():
+    global _TP_PYNCCL_COMMUNICATOR
+    return _TP_PYNCCL_COMMUNICATOR
+
+
 def get_local_rank():
     global _LOCAL_RANK
     return _LOCAL_RANK
@@ -80,10 +86,20 @@ def init_distributed_environment(
     # set the local rank
     # local_rank is not available in torch ProcessGroup,
     # see https://github.com/pytorch/pytorch/issues/122816
-    if local_rank == -1 and distributed_init_method == "env://":
-        local_rank = envs.LOCAL_RANK
+    if local_rank == -1:
+        # local rank not set, this usually happens in single-node
+        # setting, where we can use rank as local rank
+        if distributed_init_method == "env://":
+            local_rank = envs.LOCAL_RANK
+        else:
+            local_rank = rank
     global _LOCAL_RANK
     _LOCAL_RANK = local_rank
+    # A small all_reduce for warmup.
+    data = torch.zeros(1)
+    if torch.cuda.is_available():
+        data = data.to(device=f"cuda:{local_rank}")
+    torch.distributed.all_reduce(data)
 
 
 def initialize_model_parallel(
@@ -133,29 +149,36 @@ def initialize_model_parallel(
     rank = torch.distributed.get_rank()
 
     # Build the tensor model-parallel groups.
-    global _TP_DEVICE_GROUP, _TP_CPU_GROUP
+    global _TP_DEVICE_GROUP, _TP_CPU_GROUP, _TP_PYNCCL_COMMUNICATOR
     assert _TP_DEVICE_GROUP is None, (
         "tensor model parallel group is already initialized")
     for i in range(num_tensor_model_parallel_groups):
-        ranks = range(i * tensor_model_parallel_size,
-                      (i + 1) * tensor_model_parallel_size)
+        ranks = list(
+            range(i * tensor_model_parallel_size,
+                  (i + 1) * tensor_model_parallel_size))
         group = torch.distributed.new_group(ranks, backend=backend)
         cpu_group = torch.distributed.new_group(ranks, backend="gloo")
         if rank in ranks:
            _TP_DEVICE_GROUP = group
            _TP_CPU_GROUP = cpu_group
 
+    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+    _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
+        group=_TP_CPU_GROUP,
+        device=_LOCAL_RANK,
+    )
+
     # Build the pipeline model-parallel groups.
-    global _PIPELINE_MODEL_PARALLEL_GROUP
-    global _PIPELINE_GLOBAL_RANKS
-    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, (
+    global _PP_DEVICE_GROUP
+    global _PP_GLOBAL_RANKS
+    assert _PP_DEVICE_GROUP is None, (
         "pipeline model parallel group is already initialized")
     for i in range(num_pipeline_model_parallel_groups):
-        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
+        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
         group = torch.distributed.new_group(ranks, backend=backend)
         if rank in ranks:
-            _PIPELINE_MODEL_PARALLEL_GROUP = group
-            _PIPELINE_GLOBAL_RANKS = ranks
+            _PP_DEVICE_GROUP = group
+            _PP_GLOBAL_RANKS = ranks
 
 
 def ensure_model_parallel_initialized(
@@ -188,8 +211,7 @@ def ensure_model_parallel_initialized(
 
 def model_parallel_is_initialized():
     """Check if tensor and pipeline parallel groups are initialized."""
-    return (_TP_DEVICE_GROUP is not None
-            and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
+    return (_TP_DEVICE_GROUP is not None and _PP_DEVICE_GROUP is not None)
 
 
 def get_cpu_world_group():
@@ -214,9 +236,9 @@ def get_tensor_model_parallel_cpu_group():
 
 def get_pipeline_model_parallel_group():
     """Get the pipeline model parallel group the caller rank belongs to."""
-    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, (
+    assert _PP_DEVICE_GROUP is not None, (
         "pipeline model parallel group is not initialized")
-    return _PIPELINE_MODEL_PARALLEL_GROUP
+    return _PP_DEVICE_GROUP
 
 
 def get_tensor_model_parallel_world_size():
@@ -253,36 +275,36 @@ def get_tensor_model_parallel_src_rank():
 def get_pipeline_model_parallel_first_rank():
     """Return the global rank of the first process in the pipeline for the
     current tensor parallel group"""
-    assert _PIPELINE_GLOBAL_RANKS is not None, (
+    assert _PP_GLOBAL_RANKS is not None, (
         "Pipeline parallel group is not initialized")
-    return _PIPELINE_GLOBAL_RANKS[0]
+    return _PP_GLOBAL_RANKS[0]
 
 
 def get_pipeline_model_parallel_last_rank():
     """Return the global rank of the last process in the pipeline for the
     current tensor parallel group"""
-    assert _PIPELINE_GLOBAL_RANKS is not None, (
+    assert _PP_GLOBAL_RANKS is not None, (
         "Pipeline parallel group is not initialized")
     last_rank_local = get_pipeline_model_parallel_world_size() - 1
-    return _PIPELINE_GLOBAL_RANKS[last_rank_local]
+    return _PP_GLOBAL_RANKS[last_rank_local]
 
 
 def get_pipeline_model_parallel_next_rank():
     """Return the global rank that follows the caller in the pipeline"""
-    assert _PIPELINE_GLOBAL_RANKS is not None, (
+    assert _PP_GLOBAL_RANKS is not None, (
         "Pipeline parallel group is not initialized")
     rank_in_pipeline = get_pipeline_model_parallel_rank()
     world_size = get_pipeline_model_parallel_world_size()
-    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
+    return _PP_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
 
 
 def get_pipeline_model_parallel_prev_rank():
     """Return the global rank that precedes the caller in the pipeline"""
-    assert _PIPELINE_GLOBAL_RANKS is not None, (
+    assert _PP_GLOBAL_RANKS is not None, (
         "Pipeline parallel group is not initialized")
     rank_in_pipeline = get_pipeline_model_parallel_rank()
     world_size = get_pipeline_model_parallel_world_size()
-    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
+    return _PP_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
 
 
 def destroy_model_parallel():
@@ -295,45 +317,12 @@ def destroy_model_parallel():
     if _TP_CPU_GROUP:
         torch.distributed.destroy_process_group(_TP_CPU_GROUP)
     _TP_CPU_GROUP = None
-    global _PIPELINE_MODEL_PARALLEL_GROUP
-    if _PIPELINE_MODEL_PARALLEL_GROUP:
-        torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
-    _PIPELINE_MODEL_PARALLEL_GROUP = None
-    global _PIPELINE_GLOBAL_RANKS
-    _PIPELINE_GLOBAL_RANKS = None
-    from vllm.distributed.device_communicators import pynccl_utils
+    global _TP_PYNCCL_COMMUNICATOR
+    _TP_PYNCCL_COMMUNICATOR = None
 
-    # Destroy the pynccl states if any.
-    pynccl_utils.destroy_process_group()
-
-
-# Whether to use pynccl for nccl all reduce.
-# We use pynccl for all reduce when using CUDA graph, because torch.distributed
-# is not well supported by CUDA graph.
-_ENABLE_PYNCCL_FOR_ALL_REDUCE = False
-
-
-@contextlib.contextmanager
-def with_pynccl_for_all_reduce():
-    from vllm.distributed.device_communicators import pynccl_utils
-    """use pynccl instead of torch.distributed for all reduce"""
-    tp_size = get_tensor_model_parallel_world_size()
-    if tp_size == 1:
-        # No-op.
-        # NOTE(woosuk): We don't initialize pynccl when tp_size is 1.
-        yield
-    else:
-        global _ENABLE_PYNCCL_FOR_ALL_REDUCE
-        old = _ENABLE_PYNCCL_FOR_ALL_REDUCE
-        _ENABLE_PYNCCL_FOR_ALL_REDUCE = True
-
-        stream = torch.cuda.current_stream()
-        with pynccl_utils.set_pynccl_stream(stream):
-            yield
-        _ENABLE_PYNCCL_FOR_ALL_REDUCE = old
-
-
-def is_pynccl_enabled_for_all_reduce():
-    """check if pynccl is enabled for all reduce"""
-    global _ENABLE_PYNCCL_FOR_ALL_REDUCE
-    return _ENABLE_PYNCCL_FOR_ALL_REDUCE
+    global _PP_DEVICE_GROUP
+    if _PP_DEVICE_GROUP:
+        torch.distributed.destroy_process_group(_PP_DEVICE_GROUP)
+    _PP_DEVICE_GROUP = None
+    global _PP_GLOBAL_RANKS
+    _PP_GLOBAL_RANKS = None
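
To summarize how the pieces fit together after this refactor, below is a condensed, hedged restatement of the all-reduce dispatch that tensor_model_parallel_all_reduce now performs (the standalone function name is illustrative): custom all-reduce first, then the TP PyNcclCommunicator when it exists and is enabled, otherwise plain torch.distributed.

# Hedged restatement of the dispatch order after this change; the real logic
# lives in vllm/distributed/communication_op.py.
import torch

from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
                                             get_tp_pynccl_communicator)


def all_reduce_dispatch(input_: torch.Tensor) -> torch.Tensor:
    from vllm.distributed.device_communicators.custom_all_reduce import (
        custom_all_reduce)
    out = custom_all_reduce(input_)
    if out is not None:
        return out  # custom all-reduce handled it (may fall back internally)
    pynccl_comm = get_tp_pynccl_communicator()
    if pynccl_comm is not None and not pynccl_comm.disabled:
        pynccl_comm.all_reduce(input_)  # enabled e.g. under graph_capture_mode
    else:
        torch.distributed.all_reduce(
            input_, group=get_tensor_model_parallel_group())
    return input_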

vllm/worker/model_runner.py
@@ -1,4 +1,3 @@
-import contextlib
 import time
 from enum import IntEnum
 from typing import Dict, List, NamedTuple, Optional, Set, Tuple
@@ -12,9 +11,9 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
-from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
-from vllm.distributed.device_communicators import (custom_all_reduce,
-                                                    pynccl_utils)
+from vllm.distributed import broadcast_tensor_dict
+from vllm.distributed.communication_op import graph_capture_mode
+from vllm.distributed.device_communicators import custom_all_reduce
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
@@ -917,10 +916,6 @@ class ModelRunner:
         Since it is used for decoding-only, it assumes there's only 1 token
         per sequence in the batch.
         """
-        # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
-        # deleted before the CUDA graphs.
-        self.pynccl_backend = pynccl_utils.get_nccl_backend()
-
         assert not self.model_config.enforce_eager
         logger.info("Capturing the model for CUDA graphs. This may lead to "
                     "unexpected consequences if the model is not static. To "
@@ -1046,7 +1041,7 @@ class CUDAGraphRunner:
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with _maybe_pynccl():
+        with graph_capture_mode():
             self.model(
                 input_ids,
                 positions,
@@ -1061,7 +1056,7 @@ class CUDAGraphRunner:
         # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
         self._graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self._graph, pool=memory_pool):  # noqa: SIM117
-            with _maybe_pynccl():
+            with graph_capture_mode():
                 hidden_states = self.model(
                     input_ids,
                     positions,
@@ -1113,16 +1108,6 @@ class CUDAGraphRunner:
         return self.forward(*args, **kwargs)
 
 
-@contextlib.contextmanager
-def _maybe_pynccl():
-    if pynccl_utils.is_initialized(
-    ) and not custom_all_reduce.is_initialized():
-        with with_pynccl_for_all_reduce():
-            yield
-    else:
-        yield
-
-
 def _get_graph_batch_size(batch_size: int) -> int:
     """Returns the padded batch size given actual batch size.
 
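
A hypothetical sketch of capturing an all-reduce inside a CUDA graph with the refactored communicator, mirroring worker_fn_with_cudagraph from the test file (not part of the diff); the helper name is illustrative and an initialized PyNcclCommunicator is assumed.

# Hedged sketch, not from this commit: record the collective during capture,
# then replay it.
import torch

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator


def capture_and_replay(pynccl_comm: PyNcclCommunicator) -> torch.Tensor:
    graph = torch.cuda.CUDAGraph()
    a = torch.ones((4, 4), device=f"cuda:{pynccl_comm.rank}")
    torch.cuda.synchronize()
    # capture on the communicator's stream with the communicator enabled
    with torch.cuda.graph(graph, stream=pynccl_comm.stream), \
            pynccl_comm.change_state(enable=True):
        pynccl_comm.all_reduce(a)  # recorded, not executed, during capture
    pynccl_comm.stream.synchronize()
    graph.replay()  # the all-reduce actually runs here
    pynccl_comm.stream.synchronize()
    return a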

vllm/worker/worker.py
@@ -11,9 +11,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          VisionLanguageConfig)
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
-                              get_tensor_model_parallel_cpu_group,
                               init_distributed_environment)
-from vllm.distributed.device_communicators import pynccl_utils
 from vllm.distributed.device_communicators.custom_all_reduce import (
     init_custom_ar)
 from vllm.lora.request import LoRARequest
@@ -306,29 +304,10 @@ def init_worker_distributed_environment(
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
 
-    if pynccl_utils.is_initialized():
-        pynccl_world_size = pynccl_utils.get_world_size()
-        if pynccl_world_size != parallel_config.world_size:
-            raise RuntimeError(
-                "pynccl is already initialized but the pynccl world "
-                "size does not match parallel_config.world_size "
-                f"({pynccl_world_size} vs. {parallel_config.world_size}).")
-    elif parallel_config.world_size > 1:
-        # NOTE(woosuk): We don't initialize pynccl process group when world size
-        # is 1.
-        # NOTE(kaichao): By default, pynccl is initialized for tp group.
-        pynccl_utils.init_process_group(
-            group=get_tensor_model_parallel_cpu_group())
-
     # Initialize a custom fast all-reduce implementation.
     if not parallel_config.disable_custom_all_reduce:
         init_custom_ar()
 
-    # A small all_reduce for warmup.
-    torch.distributed.all_reduce(torch.zeros(1).cuda())
-    if pynccl_utils.is_initialized():
-        pynccl_utils.all_reduce(torch.zeros(1).cuda())
-
 
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.