# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""vLLM distributed state.

It takes control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
  initialize the model parallel groups.

- run any code that uses the distributed environment.

- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need the distributed environment without model/pipeline
parallelism, you can skip the model parallel initialization and destruction
steps.
"""
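# Illustrative sketch of the workflow described above (comments only, not
# executed as part of this module). The world size, rank, and init method are
# placeholder assumptions for a single-node, two-process run, and the imports
# assume these names are re-exported by `vllm.distributed`; otherwise import
# them directly from this module:
#
#     from vllm.distributed import (init_distributed_environment,
#                                   ensure_model_parallel_initialized,
#                                   destroy_model_parallel,
#                                   destroy_distributed_environment)
#
#     init_distributed_environment(
#         world_size=2, rank=0,
#         distributed_init_method="tcp://127.0.0.1:29500", local_rank=0)
#     ensure_model_parallel_initialized(tensor_model_parallel_size=2,
#                                       pipeline_model_parallel_size=1)
#     ...  # distributed code goes here
#     destroy_model_parallel()
#     destroy_distributed_environment()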
import contextlib
import pickle
import weakref
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup

import vllm.envs as envs
from vllm.logger import init_logger


@dataclass
class GraphCaptureContext:
    stream: torch.cuda.Stream


TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


def _split_tensor_dict(
        tensor_dict: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
       by its metadata.
    2. A list of tensors.
    """
    metadata_list: List[Tuple[str, Any]] = []
    tensor_list: List[torch.Tensor] = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type; the
            # receiving side will set the device index.
            device = value.device.type
            metadata_list.append(
                (key, TensorMetadata(device, value.dtype, value.size())))
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    return metadata_list, tensor_list


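# Illustrative sketch of what `_split_tensor_dict` produces (the dict values
# below are made up for the example):
#
#     d = {"x": torch.zeros(2, 3, device="cuda:0"), "step": 7}
#     metadata, tensors = _split_tensor_dict(d)
#     # metadata == [("x", TensorMetadata("cuda", torch.float32,
#     #                                   torch.Size([2, 3]))),
#     #              ("step", 7)]
#     # tensors == [d["x"]]

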
_group_name_counter: Dict[str, int] = {}


def _get_unique_name(name: str) -> str:
    """Get a unique name for the group.
    Example:
        _get_unique_name("tp") -> "tp:0"
        _get_unique_name("tp") -> "tp:1"
    """
    if name not in _group_name_counter:
        _group_name_counter[name] = 0
    newname = f"{name}:{_group_name_counter[name]}"
    _group_name_counter[name] += 1
    return newname


_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {}


def _register_group(group: "GroupCoordinator") -> None:
    # looks like Python 3.8 does not understand `ReferenceType`
    _groups[group.unique_name] = weakref.ref(group)  # type: ignore


@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"])
def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
    assert group_name in _groups, f"Group {group_name} is not found."
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    group._all_reduce(tensor)


@inplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> None:
    return


@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    assert group_name in _groups, f"Group {group_name} is not found."
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group._all_reduce(tensor)


@outplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    return torch.empty_like(tensor)


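# Illustrative sketch of how these custom ops are dispatched by name (the
# group name "tp:0" is a placeholder; real names come from `_get_unique_name`
# and are normally supplied by `GroupCoordinator.all_reduce` below):
#
#     t = torch.ones(8, device="cuda")
#     # out-of-place: returns a new tensor
#     out = torch.ops.vllm.outplace_all_reduce(t, group_name="tp:0")
#     # in-place: mutates `t` and returns None
#     torch.ops.vllm.inplace_all_reduce(t, group_name="tp:0")

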
class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
    PyTorch ProcessGroup is bound to one specific communication backend,
    e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
    the processes in the group. It can route the communication to
    a specific implementation (e.g. switch allreduce implementation
    based on the tensor size and cuda graph mode).
    """

    # available attributes:
    rank: int  # global rank
    ranks: List[int]  # global ranks in the group
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #    0    |  0   |  0   |     0      |       0
    #    1    |  0   |  1   |     1      |       1
    #    2    |  1   |  2   |     0      |       2
    #    3    |  1   |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    use_pynccl: bool  # a hint of whether to use PyNccl
    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
    # communicators are only created for world size > 1
    pynccl_comm: Optional[Any]  # PyNccl communicator
    ca_comm: Optional[Any]  # Custom allreduce communicator
    mq_broadcaster: Optional[Any]  # shared memory broadcaster

    def __init__(
        self,
        group_ranks: List[List[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        use_pynccl: bool,
        use_custom_allreduce: bool,
        use_tpu_communicator: bool,
        use_message_queue_broadcaster: bool = False,
        group_name: Optional[str] = None,
    ):
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        self.device_group = None
        self.cpu_group = None

        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend)
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self.device_group = device_group
                self.cpu_group = cpu_group

        assert self.cpu_group is not None
        assert self.device_group is not None

        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_pynccl = use_pynccl
        self.use_custom_allreduce = use_custom_allreduce
        self.use_tpu_communicator = use_tpu_communicator

        # lazy import to avoid documentation build error
        from vllm.distributed.device_communicators.custom_all_reduce import (
            CustomAllreduce)
        from vllm.distributed.device_communicators.pynccl import (
            PyNcclCommunicator)

        self.pynccl_comm: Optional[PyNcclCommunicator] = None
        if use_pynccl and self.world_size > 1:
            self.pynccl_comm = PyNcclCommunicator(
                group=self.cpu_group,
                device=self.device,
            )

        self.ca_comm: Optional[CustomAllreduce] = None
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            self.ca_comm = CustomAllreduce(
                group=self.cpu_group,
                device=self.device,
            )

        from vllm.distributed.device_communicators.tpu_communicator import (
            TpuCommunicator)
        self.tpu_communicator: Optional[TpuCommunicator] = None
        if use_tpu_communicator and self.world_size > 1:
            self.tpu_communicator = TpuCommunicator(group=self.cpu_group)

        from vllm.distributed.device_communicators.shm_broadcast import (
            MessageQueue)
        self.mq_broadcaster: Optional[MessageQueue] = None
        if use_message_queue_broadcaster and self.world_size > 1:
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6)

    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(
            self, graph_capture_context: Optional[GraphCaptureContext] = None):
        if graph_capture_context is None:
            stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(stream)
        else:
            stream = graph_capture_context.stream

        ca_comm = self.ca_comm
        maybe_ca_context = nullcontext(
        ) if ca_comm is None else ca_comm.capture()

        # ensure all initialization operations complete before attempting to
        # capture the graph on another stream
        curr_stream = torch.cuda.current_stream()
        if curr_stream != stream:
            stream.wait_stream(curr_stream)

        with torch.cuda.stream(stream), maybe_ca_context:
            # In graph mode, we have to be very careful about the collective
            # operations. The current status is:
            #     allreduce \ Mode   |  Eager  |  Graph  |
            # --------------------------------------------
            # custom allreduce       | enabled | enabled |
            # PyNccl                 | disabled| enabled |
            # torch.distributed      | enabled | disabled|
            #
            # Note that custom allreduce performs a runtime check: if the
            # tensor size is too large, it will fall back to the next
            # available option.
            # In summary: when using CUDA graph, we use either the custom
            # all-reduce kernel or pynccl. When not using CUDA graph, we use
            # either the custom all-reduce kernel or PyTorch NCCL. We always
            # prioritize the custom all-reduce kernel but fall back to
            # PyTorch NCCL or pynccl if it is disabled or not supported.
            pynccl_comm = self.pynccl_comm
            maybe_pynccl_context: Any
            if not pynccl_comm:
                maybe_pynccl_context = nullcontext()
            else:
                maybe_pynccl_context = pynccl_comm.change_state(
                    enable=True, stream=torch.cuda.current_stream())
            with maybe_pynccl_context:
                yield graph_capture_context

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. Instead, we pass the
        group name as a string, look up the group coordinator from that
        name, and dispatch the all-reduce operation to the coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
        a new tensor in the same op. So we need to figure out if the op is
        in-place or out-of-place ahead of time.
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

        if self.tpu_communicator is not None and \
            not self.tpu_communicator.disabled:
            # TPU handles Dynamo with its own logic.
            return self._all_reduce(input_)

        if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
            return torch.ops.vllm.outplace_all_reduce(
                input_, group_name=self.unique_name)
        else:
            torch.ops.vllm.inplace_all_reduce(input_,
                                              group_name=self.unique_name)
            return input_

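    # Illustrative sketch of a typical call site (assuming a tensor-parallel
    # group has already been initialized; `get_tp_group` is defined later in
    # this module):
    #
    #     tp = get_tp_group()
    #     partial = torch.randn(16, 32, device=tp.device)
    #     reduced = tp.all_reduce(partial)  # summed across all TP ranks
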
    def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
        The actual all-reduce implementation.

        NOTE: This operation will be applied in-place or out-of-place.
        Always assume this function modifies its input, but use the return
        value as the output.
        """
        ca_comm = self.ca_comm

        # For TPUs, use TPU communicator.
        tpu_comm = self.tpu_communicator
        if tpu_comm is not None and not tpu_comm.disabled:
            return tpu_comm.all_reduce(input_)

        if ca_comm is not None:
            out = ca_comm.custom_all_reduce(input_)
            if out is not None:
                return out
        pynccl_comm = self.pynccl_comm
        if (pynccl_comm is not None and not pynccl_comm.disabled):
            pynccl_comm.all_reduce(input_)
        elif input_.is_cpu:
            import intel_extension_for_pytorch as ipex
            ipex.distributed.all_reduce(input_, group=self.device_group)
        else:
            torch.distributed.all_reduce(input_, group=self.device_group)
        return input_

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

        # For TPUs, use TPU communicator.
        tpu_comm = self.tpu_communicator
        if tpu_comm is not None and not tpu_comm.disabled:
            return tpu_comm.all_gather(input_, dim)

        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
        # Allocate output tensor.
        output_tensor = torch.empty((world_size, ) + input_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        torch.distributed.all_gather_into_tensor(output_tensor,
                                                 input_,
                                                 group=self.device_group)
        # Reshape
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(input_size[:dim] +
                                              (world_size *
                                               input_size[dim], ) +
                                              input_size[dim + 1:])
        return output_tensor

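    # Illustrative sketch of the resulting shape (sizes are made up): with
    # world_size == 4, an input of shape (2, 3) and dim == -1, the gathered
    # tensor of shape (4, 2, 3) is moved to (2, 4, 3) and reshaped to (2, 12),
    # i.e. each rank's slice is concatenated along `dim`:
    #
    #     out = get_tp_group().all_gather(torch.randn(2, 3, device="cuda"),
    #                                     dim=-1)
    #     # out.shape == (2, 12) when the TP group has 4 ranks
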
    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            gather_list = None
        # Gather.
        torch.distributed.gather(input_,
                                 gather_list,
                                 dst=self.ranks[dst],
                                 group=self.device_group)
        if self.rank_in_group == dst:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
        torch.distributed.broadcast(input_,
                                    src=self.ranks[src],
                                    group=self.device_group)
        return input_

    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
        if self.rank_in_group == src:
            torch.distributed.broadcast_object_list([obj],
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return obj
        else:
            recv = [None]
            torch.distributed.broadcast_object_list(recv,
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return recv[0]

    def broadcast_object_list(self,
                              obj_list: List[Any],
                              src: int = 0,
                              group: Optional[ProcessGroup] = None):
        """Broadcast the input object list.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
        torch.distributed.broadcast_object_list(obj_list,
                                                src=self.ranks[src],
                                                group=self.device_group)
        return obj_list

    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object to the destination rank.
        NOTE: `dst` is the local rank of the destination rank.
        """
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        assert dst != self.rank_in_group, (
            "Invalid destination rank. Destination rank is the same "
            "as the current rank.")

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

        size_tensor = torch.tensor([object_tensor.numel()],
                                   dtype=torch.long,
                                   device="cpu")

        # Send object size
        torch.distributed.send(size_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        # Send object
        torch.distributed.send(object_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        return None

    def recv_object(self, src: int) -> Any:
        """Receive an object from the source rank.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        assert src != self.rank_in_group, (
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
        rank_size = torch.distributed.recv(size_tensor,
                                           src=self.ranks[src],
                                           group=self.cpu_group)

        # Tensor to receive the serialized object into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
            device="cpu")

        rank_object = torch.distributed.recv(object_tensor,
                                             src=self.ranks[src],
                                             group=self.cpu_group)

        assert rank_object == rank_size, (
            "Received object sender rank does not match the size sender rank.")

        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

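    # Illustrative sketch of the size-then-payload protocol used by
    # `send_object`/`recv_object` (ranks below are placeholders; both sends
    # go over the gloo-backed CPU group):
    #
    #     # on rank_in_group 0
    #     get_tp_group().send_object({"step": 3, "done": False}, dst=1)
    #     # on rank_in_group 1
    #     obj = get_tp_group().recv_object(src=0)  # {"step": 3, "done": False}
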
    def broadcast_tensor_dict(
        self,
        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
        src: int = 0,
        group: Optional[ProcessGroup] = None,
        metadata_group: Optional[ProcessGroup] = None
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if (not torch.distributed.is_initialized() or self.world_size == 1):
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

        rank_in_group = self.rank_in_group
        if rank_in_group == src:
            metadata_list: List[Tuple[Any, Any]] = []
            assert isinstance(
                tensor_dict,
                dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
            self.broadcast_object(metadata_list, src=src)
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    handle = torch.distributed.broadcast(tensor,
                                                         src=self.ranks[src],
                                                         group=metadata_group,
                                                         async_op=True)
                else:
                    # use group for GPU tensors
                    handle = torch.distributed.broadcast(tensor,
                                                         src=self.ranks[src],
                                                         group=group,
                                                         async_op=True)
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()

        else:
            metadata_list = self.broadcast_object(None, src=src)
            tensor_dict = {}
            async_handles = []
            for key, value in metadata_list:
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(value.size,
                                         dtype=value.dtype,
                                         device=value.device)
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
                        tensor_dict[key] = tensor
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=metadata_group,
                            async_op=True)
                    else:
                        # use group for GPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=group,
                            async_op=True)
                    async_handles.append(handle)
                    tensor_dict[key] = tensor
                else:
                    tensor_dict[key] = value
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

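    # Illustrative sketch (values are placeholders): the source rank first
    # broadcasts the metadata list over the CPU group, then broadcasts each
    # non-empty tensor asynchronously; the other ranks rebuild the dict from
    # the received metadata.
    #
    #     # on the source rank (rank_in_group 0)
    #     get_tp_group().broadcast_tensor_dict(
    #         {"hidden": torch.randn(4, 8, device="cuda"), "step": 3}, src=0)
    #     # on every other rank
    #     d = get_tp_group().broadcast_tensor_dict(src=0)
    #     # d["hidden"].shape == (4, 8); d["step"] == 3
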
    def send_tensor_dict(
        self,
        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
        dst: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the destination rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        metadata_list: List[Tuple[Any, Any]] = []
        assert isinstance(
            tensor_dict,
            dict), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue

            # send-allgather: send only a slice, then do allgather.
            if (all_gather_group is not None
                    and tensor.numel() % all_gather_size == 0):
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=metadata_group)
            else:
                # use group for GPU tensors
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=group)
        return None

    def recv_tensor_dict(
        self,
        src: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Receive the tensor dictionary sent by the source rank.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        recv_metadata_list = self.recv_object(src=src)
        tensor_dict: Dict[str, Any] = {}
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device=value.device)
                if tensor.numel() == 0:
                    # Skip receiving empty tensors.
                    tensor_dict[key] = tensor
                    continue

                # send-allgather: send only a slice, then do allgather.
                use_all_gather = (all_gather_group is not None
                                  and tensor.numel() % all_gather_size == 0)

                if use_all_gather:
                    orig_shape = tensor.shape
                    tensor = tensor.reshape(all_gather_size,
                                            -1)[all_gather_rank]

                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=metadata_group)
                else:
                    # use group for GPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=group)
                if use_all_gather:
                    # do the allgather
                    tensor = all_gather_group.all_gather(  # type: ignore
                        tensor, dim=0)
                    tensor = tensor.reshape(orig_shape)

                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        return tensor_dict

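    # Illustrative sketch of the point-to-point use between adjacent pipeline
    # stages (the defaults send to the next rank and receive from the previous
    # one; `get_pp_group` is defined later in this module):
    #
    #     # on pipeline stage i
    #     get_pp_group().send_tensor_dict({"hidden_states": h})
    #     # on pipeline stage i + 1
    #     d = get_pp_group().recv_tensor_dict()
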
    def barrier(self):
        """Barrier synchronization among the group.
        NOTE: don't use `device_group` here! `barrier` in NCCL is
        terrible because it is internally a broadcast operation with
        secretly created GPU tensors. It is easy to mess up the current
        device. Use the CPU group instead.
        """
        torch.distributed.barrier(group=self.cpu_group)

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Send a tensor to the destination rank in a non-blocking way.
        NOTE: `dst` is the local rank of the destination rank.
        """
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size

        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.send(tensor, dst)
        else:
            torch.distributed.send(tensor, self.ranks[dst], self.device_group)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receive a tensor from the source rank.
        NOTE: `src` is the local rank of the source rank.
        """
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size

        tensor = torch.empty(size, dtype=dtype, device=self.device)
        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.recv(tensor, src)
        else:
            torch.distributed.recv(tensor, self.ranks[src], self.device_group)
        return tensor

    def destroy(self):
        if self.device_group is not None:
            torch.distributed.destroy_process_group(self.device_group)
            self.device_group = None
        if self.cpu_group is not None:
            torch.distributed.destroy_process_group(self.cpu_group)
            self.cpu_group = None
        if self.pynccl_comm is not None:
            self.pynccl_comm = None
        if self.ca_comm is not None:
            self.ca_comm = None
        if self.mq_broadcaster is not None:
            self.mq_broadcaster = None


_WORLD: Optional[GroupCoordinator] = None


def get_world_group() -> GroupCoordinator:
    assert _WORLD is not None, ("world group is not initialized")
    return _WORLD


def init_world_group(ranks: List[int], local_rank: int,
                     backend: str) -> GroupCoordinator:
    return GroupCoordinator(
        group_ranks=[ranks],
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=False,
        use_custom_allreduce=False,
        use_tpu_communicator=False,
        group_name="world",
    )


def init_model_parallel_group(
    group_ranks: List[List[int]],
    local_rank: int,
    backend: str,
    use_custom_allreduce: Optional[bool] = None,
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
) -> GroupCoordinator:
    if use_custom_allreduce is None:
        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
    return GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=True,
        use_custom_allreduce=use_custom_allreduce,
        use_tpu_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        group_name=group_name,
    )


_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    assert _TP is not None, ("tensor model parallel group is not initialized")
    return _TP


# kept for backward compatibility
get_tensor_model_parallel_group = get_tp_group

_PP: Optional[GroupCoordinator] = None


def get_pp_group() -> GroupCoordinator:
    assert _PP is not None, (
        "pipeline model parallel group is not initialized")
    return _PP


# kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group


@contextmanager
def graph_capture():
    """
    `graph_capture` is a context manager which should surround the code that
    is capturing the CUDA graph. Its main purpose is to ensure that some
    operations will be run after the graph is captured, before the graph
    is replayed. It returns a `GraphCaptureContext` object which contains the
    necessary data for the graph capture. Currently, it only contains the
    stream that the graph capture is running on. This stream is set to the
    current CUDA stream when the context manager is entered and reset to the
    default stream when the context manager is exited. This is to ensure that
    the graph capture is running on a separate stream from the default stream,
    in order to explicitly distinguish the kernels to capture
    from other kernels possibly launched in the background on the default
    stream.
    """
    with get_tp_group().graph_capture() as context, get_pp_group(
    ).graph_capture(context):
        yield context


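# Illustrative sketch of capturing a CUDA graph under this context manager
# (the callable `run_model` and its static input are placeholders):
#
#     graph = torch.cuda.CUDAGraph()
#     with graph_capture() as capture_ctx:
#         with torch.cuda.graph(graph, stream=capture_ctx.stream):
#             output = run_model(static_input)
#     # later: replay the captured kernels
#     graph.replay()

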
logger = init_logger(__name__)

_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable


def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    logger.debug(
        "world_size=%d rank=%d local_rank=%d "
        "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
        distributed_init_method, backend)
    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment")
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank)
    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        if distributed_init_method == "env://":
            local_rank = envs.LOCAL_RANK
        else:
            local_rank = rank
    global _WORLD
    if _WORLD is None:
        ranks = list(range(torch.distributed.get_world_size()))
        _WORLD = init_world_group(ranks, local_rank, backend)
    else:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size")


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    if (world_size !=
            tensor_model_parallel_size * pipeline_model_parallel_size):
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

    # Build the tensor model-parallel groups.
    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    global _TP
    assert _TP is None, ("tensor model parallel group is already initialized")
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(
            range(i * tensor_model_parallel_size,
                  (i + 1) * tensor_model_parallel_size))
        group_ranks.append(ranks)

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_message_queue_broadcaster=True,
                                    group_name="tp")

    # Build the pipeline model-parallel groups.
    num_pipeline_model_parallel_groups: int = (world_size //
                                               pipeline_model_parallel_size)
    global _PP
    assert _PP is None, (
        "pipeline model parallel group is already initialized")
    group_ranks = []
    for i in range(num_pipeline_model_parallel_groups):
        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
        group_ranks.append(ranks)
    # pipeline parallel does not need custom allreduce
    _PP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_custom_allreduce=False,
                                    group_name="pp")


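# Illustrative sketch of the group layout computed above for world_size == 8,
# tensor_model_parallel_size == 2, pipeline_model_parallel_size == 4 (the same
# example as in the docstring):
#
#     tp_groups = [list(range(i * 2, (i + 1) * 2)) for i in range(8 // 2)]
#     # tp_groups == [[0, 1], [2, 3], [4, 5], [6, 7]]
#     pp_groups = [list(range(i, 8, 8 // 4)) for i in range(8 // 4)]
#     # pp_groups == [[0, 2, 4, 6], [1, 3, 5, 7]]

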
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
    values if the model parallel groups are initialized.
    """
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)
    if not model_parallel_is_initialized():
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)
        return

    assert (
        get_tensor_model_parallel_world_size() == tensor_model_parallel_size
    ), ("tensor parallel group already initialized, but of unexpected size: "
        f"{get_tensor_model_parallel_world_size()=} vs. "
        f"{tensor_model_parallel_size=}")
    pp_world_size = get_pp_group().world_size
    assert (pp_world_size == pipeline_model_parallel_size), (
        "pipeline parallel group already initialized, but of unexpected size: "
        f"{pp_world_size=} vs. "
        f"{pipeline_model_parallel_size=}")


def model_parallel_is_initialized():
    """Check if tensor and pipeline parallel groups are initialized."""
    return (_TP is not None and _PP is not None)


_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run the draft
    model with a different tp degree from that of the target model workers.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator
    """
    global _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    _TP_STATE_PATCHED = True
    old_tp_group = get_tp_group()
    global _TP
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


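# Illustrative sketch of temporarily swapping in a smaller tp group (the
# `draft_tp_group` and `run_draft_model` names are placeholders created
# elsewhere by the caller):
#
#     with patch_tensor_parallel_group(draft_tp_group):
#         # code here sees get_tp_group() == draft_tp_group
#         run_draft_model()
#     # the original tp group is restored on exit

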
def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    return get_tp_group().world_size


def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group."""
    return get_tp_group().rank_in_group


def destroy_model_parallel():
    """Set the groups to none and destroy them."""
    global _TP
    if _TP:
        _TP.destroy()
    _TP = None

    global _PP
    if _PP:
        _PP.destroy()
    _PP = None


def destroy_distributed_environment():
    global _WORLD
    if _WORLD:
        _WORLD.destroy()
    _WORLD = None
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()


def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
    """
    This is a collective operation that returns whether each rank is in the
    same node as the source rank. It tests if processes are attached to the
    same memory system (shared access to shared memory).
    """
    assert torch.distributed.get_backend(
        pg) != torch.distributed.Backend.NCCL, (
            "in_the_same_node_as should be tested with a non-NCCL group.")
    # local rank inside the group
    rank = torch.distributed.get_rank(group=pg)
    world_size = torch.distributed.get_world_size(group=pg)

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    # global ranks of the processes in the group
    ranks = torch.distributed.get_process_group_ranks(pg)

    magic_message = b"magic_message"
    shm = None

    try:
        with contextlib.suppress(OSError):
            if rank == source_rank:
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[:len(magic_message)] = magic_message
                torch.distributed.broadcast_object_list([shm.name],
                                                        src=ranks[source_rank],
                                                        group=pg)
                is_in_the_same_node[rank] = 1
            else:
                # try to open the shared memory segment
                recv = [None]
                torch.distributed.broadcast_object_list(recv,
                                                        src=ranks[source_rank],
                                                        group=pg)
                name = recv[0]
                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch("multiprocessing.resource_tracker.register",
                           lambda *args, **kwargs: None):
                    shm = shared_memory.SharedMemory(name=name)
                if shm.buf[:len(magic_message)] == magic_message:
                    is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm:
            shm.close()

    torch.distributed.barrier(group=pg)

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm:
            shm.unlink()
    torch.distributed.all_reduce(is_in_the_same_node, group=pg)

    return [x == 1 for x in is_in_the_same_node.tolist()]