[core][distributed] add zmq fallback for broadcasting large objects (#6183)

parent 2416b26e11
commit da78caecfa
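
For orientation, the writer-side decision this commit introduces can be sketched roughly as below. This is an illustrative stand-alone snippet, not code from the diff: MAX_CHUNK_BYTES, the socket address, and enqueue_sketch are hypothetical stand-ins for the MessageQueue internals shown further down. Small pickled objects stay in the shared-memory ring buffer; anything at or above the chunk limit is published over a zmq socket instead.

import pickle

import zmq

# Hypothetical constants and socket for illustration only.
MAX_CHUNK_BYTES = 10 * 1024 * 1024

ctx = zmq.Context()
pub = ctx.socket(zmq.PUB)
pub.bind("tcp://*:5555")


def enqueue_sketch(obj, shm_chunk: bytearray) -> None:
    # Mirrors the writer-side decision: byte 0 of the shared-memory chunk is
    # an overflow flag; the payload either follows it or goes over zmq PUB.
    serialized = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    if len(serialized) >= MAX_CHUNK_BYTES:
        shm_chunk[0] = 1  # too large: readers fall back to the zmq socket
        pub.send(serialized)
    else:
        shm_chunk[0] = 0  # fits: payload lives in the shared-memory chunk
        shm_chunk[1:len(serialized) + 1] = serialized
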
@@ -21,3 +21,4 @@ lm-format-enforcer == 0.10.1
 outlines >= 0.0.43 # Requires torch >= 2.1.0
 typing_extensions
 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
+pyzmq
@@ -2,10 +2,11 @@ import os
 
 import torch
 
-from vllm.distributed.parallel_state import is_in_the_same_node
+from vllm.distributed.parallel_state import in_the_same_node_as
 
 torch.distributed.init_process_group(backend="gloo")
-test_result = is_in_the_same_node(torch.distributed.group.WORLD)
+test_result = all(
+    in_the_same_node_as(torch.distributed.group.WORLD, source_rank=0))
 
 expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
 assert test_result == expected, f"Expected {expected}, got {test_result}"
@@ -6,8 +6,7 @@ from typing import List
 import numpy as np
 import torch.distributed as dist
 
-from vllm.distributed.device_communicators.shm_broadcast import (
-    ShmRingBuffer, ShmRingBufferIO)
+from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
 from vllm.utils import update_environment_variables
 
 
@@ -56,8 +55,8 @@ def worker_fn_wrapper(fn):
 @worker_fn_wrapper
 def worker_fn():
     writer_rank = 2
-    broadcaster = ShmRingBufferIO.create_from_process_group(
-        dist.group.WORLD, 1024 * 1024, 2, writer_rank)
+    broadcaster = MessageQueue.create_from_process_group(
+        dist.group.WORLD, 40 * 1024, 2, writer_rank)
     if dist.get_rank() == writer_rank:
         seed = random.randint(0, 1000)
         dist.broadcast_object_list([seed], writer_rank)
@@ -87,13 +86,3 @@ def worker_fn():
 
 def test_shm_broadcast():
     distributed_run(worker_fn, 4)
-
-
-def test_singe_process():
-    buffer = ShmRingBuffer(1, 1024, 4)
-    reader = ShmRingBufferIO(buffer, reader_rank=0)
-    writer = ShmRingBufferIO(buffer, reader_rank=-1)
-    writer.enqueue([0])
-    writer.enqueue([1])
-    assert reader.dequeue() == [0]
-    assert reader.dequeue() == [1]
@@ -9,7 +9,7 @@ import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
-from vllm.distributed.parallel_state import is_in_the_same_node
+from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.logger import init_logger
 from vllm.utils import cuda_device_count_stateless, is_full_nvlink
 
@@ -64,7 +64,7 @@ class CustomAllreduce:
         assert dist.get_backend(group) != dist.Backend.NCCL, (
             "CustomAllreduce should be attached to a non-NCCL group.")
 
-        if not is_in_the_same_node(group):
+        if not all(in_the_same_node_as(group, source_rank=0)):
             # No need to initialize custom allreduce for multi-node case.
             logger.warning(
                 "Custom allreduce is disabled because this process group"
@@ -1,16 +1,19 @@
 import pickle
 import time
 from contextlib import contextmanager
+from dataclasses import dataclass, field
 from multiprocessing import shared_memory
-from typing import Optional
+from typing import List, Optional
 from unittest.mock import patch
 
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
+from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context  # type: ignore
 
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.utils import get_ip, get_open_port
 
 VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
 
@@ -135,18 +138,183 @@ class ShmRingBuffer:
             yield buf


-class ShmRingBufferIO:
-
-    def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
-        self.buffer = buffer
-        self.reader_rank = reader_rank
-        self._is_writer = self.reader_rank == -1
-        self._is_reader = not self._is_writer
-        if self._is_reader:
-            assert 0 <= self.reader_rank < buffer.n_reader, \
-                (f"Invalid reader rank {self.reader_rank} for buffer"
-                 f" created with {buffer.n_reader} readers")
-        self.current_idx = 0
+@dataclass
+class Handle:
+    connect_ip: str
+    local_reader_ranks: List[int] = field(default_factory=list)
+
+    buffer: Optional[ShmRingBuffer] = None
+    local_subscribe_port: Optional[int] = None
+    local_sync_port: Optional[int] = None
+    remote_subscribe_port: Optional[int] = None
+    remote_sync_port: Optional[int] = None
+
+
+class MessageQueue:
+
+    def __init__(
+        self,
+        n_reader,  # number of all readers
+        n_local_reader,  # number of local readers through shared memory
+        local_reader_ranks: Optional[List[int]] = None,
+        max_chunk_bytes: int = 1024 * 1024 * 10,
+        max_chunks: int = 10,
+        connect_ip: Optional[str] = None,
+    ):
+        if local_reader_ranks is None:
+            local_reader_ranks = list(range(n_local_reader))
+        else:
+            assert len(local_reader_ranks) == n_local_reader
+        self.n_local_reader = n_local_reader
+        n_remote_reader = n_reader - n_local_reader
+        self.n_remote_reader = n_remote_reader
+
+        if connect_ip is None:
+            connect_ip = get_ip()
+
+        context = Context()
+
+        if n_local_reader > 0:
+            # for local readers, we will:
+            # 1. create a shared memory ring buffer to communicate small data
+            # 2. create a publish-subscribe socket to communicate large data
+            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
+                                        max_chunks)
+
+            self.local_socket = context.socket(PUB)
+            local_subscribe_port = get_open_port()
+            self.local_socket.bind(f"tcp://*:{local_subscribe_port}")
+
+            self.local_sync_socket = context.socket(REP)
+            local_sync_port = get_open_port()
+            self.local_sync_socket.bind(f"tcp://*:{local_sync_port}")
+            self.current_idx = 0
+
+        else:
+            self.buffer = None  # type: ignore
+            local_subscribe_port = None
+            local_sync_port = None
+            self.local_socket = None
+            self.local_sync_socket = None
+            self.current_idx = -1
+
+        if n_remote_reader > 0:
+            # for remote readers, we will:
+            # create a publish-subscribe socket to communicate large data
+            self.remote_socket = context.socket(PUB)
+            remote_subscribe_port = get_open_port()
+            self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")
+
+            self.remote_sync_socket = context.socket(REP)
+            remote_sync_port = get_open_port()
+            self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}")
+        else:
+            remote_subscribe_port = None
+            remote_sync_port = None
+            self.remote_socket = None
+            self.remote_sync_socket = None
+
+        self._is_writer = True
+        self._is_local_reader = False
+        self.local_reader_rank = -1
+        # rank does not matter for remote readers
+        self._is_remote_reader = False
+
+        self.handle = Handle(
+            connect_ip=connect_ip,
+            local_reader_ranks=local_reader_ranks,
+            buffer=self.buffer,
+            local_subscribe_port=local_subscribe_port,
+            local_sync_port=local_sync_port,
+            remote_subscribe_port=remote_subscribe_port,
+            remote_sync_port=remote_sync_port,
+        )
+
+    def export_handle(self) -> Handle:
+        return self.handle
+
+    @staticmethod
+    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
+        self = MessageQueue.__new__(MessageQueue)
+        self.handle = handle
+        self._is_writer = False
+
+        context = Context()
+
+        if rank in handle.local_reader_ranks:
+            assert handle.buffer is not None
+            self.buffer = handle.buffer
+            self.current_idx = 0
+            self.local_reader_rank = handle.local_reader_ranks.index(rank)
+            self._is_local_reader = True
+            self._is_remote_reader = False
+
+            self.local_socket = context.socket(SUB)
+            self.local_socket.setsockopt_string(SUBSCRIBE, "")
+            self.local_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")
+
+            self.local_sync_socket = context.socket(REQ)
+            self.local_sync_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.local_sync_port}")
+
+            self.remote_socket = None
+            self.remote_sync_socket = None
+        else:
+            self.buffer = None  # type: ignore
+            self.current_idx = -1
+            self.local_reader_rank = -1
+            self._is_local_reader = False
+            self._is_remote_reader = True
+
+            self.local_socket = None
+            self.local_sync_socket = None
+
+            self.remote_socket = context.socket(SUB)
+            self.remote_socket.setsockopt_string(SUBSCRIBE, "")
+            self.remote_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")
+
+            self.remote_sync_socket = context.socket(REQ)
+            self.remote_sync_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.remote_sync_port}")
+
+        return self
+
+    def wait_until_ready(self):
+        """This is a collective operation. All processes (including the
+        readers and the writer) should call this function.
+        """
+        if self._is_writer:
+            # wait for all readers to connect
+
+            # local readers
+            for i in range(self.n_local_reader):
+                recv = self.local_sync_socket.recv()
+                assert recv == b"READY"
+                self.local_sync_socket.send(b"READY")
+            if self.n_local_reader > 0:
+                self.local_socket.send(b"READY")
+
+            # remote readers
+            for i in range(self.n_remote_reader):
+                recv = self.remote_sync_socket.recv()
+                assert recv == b"READY"
+                self.remote_sync_socket.send(b"READY")
+            if self.n_remote_reader > 0:
+                self.remote_socket.send(b"READY")
+        elif self._is_local_reader:
+            self.local_sync_socket.send(b"READY")
+            recv = self.local_sync_socket.recv()
+            assert recv == b"READY"
+            recv = self.local_socket.recv()
+            assert recv == b"READY"
+        elif self._is_remote_reader:
+            self.remote_sync_socket.send(b"READY")
+            recv = self.remote_sync_socket.recv()
+            assert recv == b"READY"
+            recv = self.remote_socket.recv()
+            assert recv == b"READY"

     @contextmanager
     def acquire_write(self):
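
As a rough usage sketch (not part of the diff), the handle-based setup looks like this. In practice the writer and the reader live in different processes, each side calls wait_until_ready() (a collective handshake over the sync sockets) before exchanging objects, and the handle travels via broadcast_object_list, as create_from_process_group further below shows.

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

# Sketch: one writer, one local reader, no remote readers, single node.
writer = MessageQueue(n_reader=1, n_local_reader=1)
handle = writer.export_handle()

# The reader side rebuilds its endpoint from the handle it received.
reader = MessageQueue.create_from_handle(handle, rank=0)

# Small object: travels through the shared-memory ring buffer.
writer.enqueue({"seed": 42})
assert reader.dequeue() == {"seed": 42}
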
@@ -201,12 +369,12 @@ class ShmRingBufferIO:
 
     @contextmanager
     def acquire_read(self):
-        assert self._is_reader, "Only readers can acquire read"
+        assert self._is_local_reader, "Only readers can acquire read"
         start_time = time.monotonic()
         n_warning = 1
         while True:
             with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
-                read_flag = metadata_buffer[self.reader_rank + 1]
+                read_flag = metadata_buffer[self.local_reader_rank + 1]
                 written_flag = metadata_buffer[0]
                 if not written_flag or read_flag:
                     # this block is either
@@ -236,7 +404,7 @@ class ShmRingBufferIO:
 
                 # caller has read from the buffer
                 # set the read flag
-                metadata_buffer[self.reader_rank + 1] = 1
+                metadata_buffer[self.local_reader_rank + 1] = 1
                 self.current_idx = (self.current_idx +
                                     1) % self.buffer.max_chunks
                 break
@@ -244,21 +412,36 @@ class ShmRingBufferIO:
     def enqueue(self, obj):
         assert self._is_writer, "Only writers can enqueue"
         serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
-        if len(serialized_obj) > self.buffer.max_chunk_bytes:
-            raise RuntimeError(
-                f"{len(serialized_obj)=} larger than the allowed value "
-                f"{self.buffer.max_chunk_bytes},"
-                "Please increase the max_chunk_bytes parameter.")
-        with self.acquire_write() as buf:
-            buf[:len(serialized_obj)] = serialized_obj
+        if self.n_local_reader > 0:
+            if len(serialized_obj) >= self.buffer.max_chunk_bytes:
+                with self.acquire_write() as buf:
+                    buf[0] = 1  # overflow
+                self.local_socket.send(serialized_obj)
+            else:
+                with self.acquire_write() as buf:
+                    buf[0] = 0  # not overflow
+                    buf[1:len(serialized_obj) + 1] = serialized_obj
+        if self.n_remote_reader > 0:
+            self.remote_socket.send(serialized_obj)
 
     def dequeue(self):
-        assert self._is_reader, "Only readers can dequeue"
-        with self.acquire_read() as buf:
-            # no need to know the size of serialized object
-            # pickle format itself contains the size information internally
-            # see https://docs.python.org/3/library/pickle.html
-            obj = pickle.loads(buf)
+        if self._is_local_reader:
+            overflow = False
+            with self.acquire_read() as buf:
+                overflow = buf[0] == 1
+                if not overflow:
+                    # no need to know the size of serialized object
+                    # pickle format contains the size information internally
+                    # see https://docs.python.org/3/library/pickle.html
+                    obj = pickle.loads(buf[1:])
+            if overflow:
+                recv = self.local_socket.recv()
+                obj = pickle.loads(recv)
+        elif self._is_remote_reader:
+            recv = self.remote_socket.recv()
+            obj = pickle.loads(recv)
+        else:
+            raise RuntimeError("Only readers can dequeue")
         return obj
 
     def broadcast_object(self, obj=None):
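
One detail worth noting in the new dequeue(): a local reader still passes through acquire_read() even when the chunk only carries the overflow flag, so its per-reader read flag in the chunk metadata is set and the writer can reuse the slot once every reader has seen it. A toy illustration of that metadata convention (one written flag plus one read flag per local reader), using a plain bytearray in place of the shared-memory segment; the constants and names here are illustrative only:

N_LOCAL_READERS = 2

# metadata[0] is the written flag; metadata[1 + i] is reader i's read flag.
metadata = bytearray(1 + N_LOCAL_READERS)

# Writer marks the chunk as written and clears all read flags.
metadata[0] = 1
for i in range(N_LOCAL_READERS):
    metadata[1 + i] = 0

# Each local reader marks the chunk as consumed, whether the payload was
# inline or fetched from the zmq socket because of the overflow flag.
for local_reader_rank in range(N_LOCAL_READERS):
    metadata[1 + local_reader_rank] = 1

# The writer may overwrite the chunk only once every reader has seen it.
can_reuse_chunk = all(metadata[1 + i] == 1 for i in range(N_LOCAL_READERS))
assert can_reuse_chunk
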
@@ -272,24 +455,36 @@ class ShmRingBufferIO:
     def create_from_process_group(pg: ProcessGroup,
                                   max_chunk_bytes,
                                   max_chunks,
-                                  writer_rank=0) -> "ShmRingBufferIO":
+                                  writer_rank=0) -> "MessageQueue":
         group_rank = dist.get_rank(pg)
         group_world_size = dist.get_world_size(pg)
-        ranks_inside_group = list(range(group_world_size))
         global_ranks = dist.get_process_group_ranks(pg)
+
+        from vllm.distributed.parallel_state import in_the_same_node_as
+        status = in_the_same_node_as(pg, source_rank=writer_rank)
+        same_node_ranks = [i for i, s in enumerate(status) if s]
         n_reader = group_world_size - 1
-        buffer: ShmRingBuffer
+        n_local_reader = len(same_node_ranks) - 1
+        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
+        buffer_io: MessageQueue
         if group_rank == writer_rank:
-            buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks)
-            dist.broadcast_object_list([buffer],
+            buffer_io = MessageQueue(
+                n_reader=n_reader,
+                n_local_reader=n_local_reader,
+                local_reader_ranks=local_reader_ranks,
+                max_chunk_bytes=max_chunk_bytes,
+                max_chunks=max_chunks,
+            )
+            handle = buffer_io.export_handle()
+            dist.broadcast_object_list([handle],
                                        src=global_ranks[writer_rank],
                                        group=pg)
-            return ShmRingBufferIO(buffer, -1)
         else:
             recv = [None]
             dist.broadcast_object_list(recv,
                                        src=global_ranks[writer_rank],
                                        group=pg)
-            buffer = recv[0]  # type: ignore
-            rest_ranks = [r for r in ranks_inside_group if r != writer_rank]
-            return ShmRingBufferIO(buffer, rest_ranks.index(group_rank))
+            handle = recv[0]  # type: ignore
+            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
+        buffer_io.wait_until_ready()
+        return buffer_io
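
For completeness, a hedged end-to-end sketch of the process-group path, assuming a torchrun-style multi-process launch with the gloo backend (mirroring the updated test earlier in this diff):

import torch.distributed as dist

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

# Sketch: launch with e.g. `torchrun --nproc-per-node=4 this_script.py`.
dist.init_process_group(backend="gloo")

writer_rank = 0
# create_from_process_group works out which readers share the writer's node
# (via in_the_same_node_as) and already calls wait_until_ready() internally.
queue = MessageQueue.create_from_process_group(dist.group.WORLD,
                                               1024 * 1024, 2,
                                               writer_rank=writer_rank)

if dist.get_rank() == writer_rank:
    queue.enqueue(list(range(10)))
else:
    print(dist.get_rank(), queue.dequeue())
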
@@ -124,7 +124,7 @@ class GroupCoordinator:
     # communicators are only created for world size > 1
     pynccl_comm: Optional[Any]  # PyNccl communicator
     ca_comm: Optional[Any]  # Custom allreduce communicator
-    shm_broadcaster: Optional[Any]  # shared memory broadcaster
+    mq_broadcaster: Optional[Any]  # shared memory broadcaster
 
     def __init__(
         self,
@@ -133,6 +133,7 @@ class GroupCoordinator:
         torch_distributed_backend: Union[str, Backend],
         use_pynccl: bool,
         use_custom_allreduce: bool,
+        use_message_queue_broadcaster: bool = False,
     ):
 
         self.rank = torch.distributed.get_rank()
@@ -190,10 +191,10 @@ class GroupCoordinator:
             self.ca_comm = None
 
         from vllm.distributed.device_communicators.shm_broadcast import (
-            ShmRingBufferIO)
-        self.shm_broadcaster: Optional[ShmRingBufferIO] = None
-        if self.world_size > 1 and is_in_the_same_node(self.cpu_group):
-            self.shm_broadcaster = ShmRingBufferIO.create_from_process_group(
+            MessageQueue)
+        self.mq_broadcaster: Optional[MessageQueue] = None
+        if use_message_queue_broadcaster and self.world_size > 1:
+            self.mq_broadcaster = MessageQueue.create_from_process_group(
                 self.cpu_group, 1 << 22, 6)
 
     @property
@@ -377,9 +378,9 @@ class GroupCoordinator:
         # Bypass the function if we are using only 1 GPU.
         if self.world_size == 1:
             return obj
-        if self.shm_broadcaster is not None:
-            assert src == 0, "Shared memory broadcaster only supports src=0"
-            return self.shm_broadcaster.broadcast_object(obj)
+        if self.mq_broadcaster is not None:
+            assert src == 0, "Message queue broadcaster only supports src=0"
+            return self.mq_broadcaster.broadcast_object(obj)
         if self.rank_in_group == src:
             torch.distributed.broadcast_object_list([obj],
                                                     src=self.ranks[src],
@@ -696,8 +697,8 @@ class GroupCoordinator:
             self.pynccl_comm = None
         if self.ca_comm is not None:
             self.ca_comm = None
-        if self.shm_broadcaster is not None:
-            self.shm_broadcaster = None
+        if self.mq_broadcaster is not None:
+            self.mq_broadcaster = None
 
 
 _WORLD: Optional[GroupCoordinator] = None
@@ -720,10 +721,12 @@ def init_world_group(ranks: List[int], local_rank: int,
 
 
 def init_model_parallel_group(
-        group_ranks: List[List[int]],
-        local_rank: int,
-        backend: str,
-        use_custom_allreduce: Optional[bool] = None) -> GroupCoordinator:
+    group_ranks: List[List[int]],
+    local_rank: int,
+    backend: str,
+    use_custom_allreduce: Optional[bool] = None,
+    use_message_queue_broadcaster: bool = False,
+) -> GroupCoordinator:
     if use_custom_allreduce is None:
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
     return GroupCoordinator(
@@ -732,6 +735,7 @@ def init_model_parallel_group(
         torch_distributed_backend=backend,
         use_pynccl=True,
         use_custom_allreduce=use_custom_allreduce,
+        use_message_queue_broadcaster=use_message_queue_broadcaster,
     )
 
 
@@ -880,8 +884,12 @@ def initialize_model_parallel(
             range(i * tensor_model_parallel_size,
                   (i + 1) * tensor_model_parallel_size))
         group_ranks.append(ranks)
 
+    # message queue broadcaster is only used in tensor model parallel group
     _TP = init_model_parallel_group(group_ranks,
-                                    get_world_group().local_rank, backend)
+                                    get_world_group().local_rank,
+                                    backend,
+                                    use_message_queue_broadcaster=True)
 
     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = (world_size //
@@ -993,15 +1001,15 @@ def destroy_distributed_environment():
         torch.distributed.destroy_process_group()
 
 
-def is_in_the_same_node(pg: ProcessGroup):
+def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
     """
-    This is a collective operation that checks if all processes in the group
-    are in the same node. It tests if all processes are attached to the same
+    This is a collective operation that returns if each rank is in the same node
+    as the source rank. It tests if processes are attached to the same
     memory system (shared access to shared memory).
     """
     assert torch.distributed.get_backend(
         pg) != torch.distributed.Backend.NCCL, (
-            "is_in_the_same_node should be tested with a non-NCCL group.")
+            "in_the_same_node_as should be tested with a non-NCCL group.")
     # local rank inside the group
     rank = torch.distributed.get_rank(group=pg)
     world_size = torch.distributed.get_world_size(group=pg)
@@ -1017,19 +1025,19 @@ def is_in_the_same_node(pg: ProcessGroup):
 
     try:
         with contextlib.suppress(OSError):
-            if rank == 0:
+            if rank == source_rank:
                 # create a shared memory segment
                 shm = shared_memory.SharedMemory(create=True, size=128)
                 shm.buf[:len(magic_message)] = magic_message
                 torch.distributed.broadcast_object_list([shm.name],
-                                                        src=ranks[0],
+                                                        src=ranks[source_rank],
                                                         group=pg)
-                is_in_the_same_node[0] = 1
+                is_in_the_same_node[rank] = 1
             else:
                 # try to open the shared memory segment
                 recv = [None]
                 torch.distributed.broadcast_object_list(recv,
-                                                        src=ranks[0],
+                                                        src=ranks[source_rank],
                                                         group=pg)
                 name = recv[0]
                 # fix to https://stackoverflow.com/q/62748654/9191338
@@ -1050,8 +1058,8 @@ def is_in_the_same_node(pg: ProcessGroup):
 
     # clean up the shared memory segment
     with contextlib.suppress(OSError):
-        if rank == 0 and shm:
+        if rank == source_rank and shm:
             shm.unlink()
     torch.distributed.all_reduce(is_in_the_same_node, group=pg)
 
-    return is_in_the_same_node.sum().item() == world_size
+    return [x == 1 for x in is_in_the_same_node.tolist()]
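
Since in_the_same_node_as() now returns one boolean per rank rather than a single flag, callers can both recover the old all-on-one-node check and select per-node subsets, as the custom-allreduce and MessageQueue changes above do. A small illustrative snippet, assuming a torchrun-style launch with a non-NCCL group:

import torch.distributed as dist

from vllm.distributed.parallel_state import in_the_same_node_as

dist.init_process_group(backend="gloo")
pg = dist.group.WORLD

status = in_the_same_node_as(pg, source_rank=0)  # List[bool], one entry per rank

all_on_one_node = all(status)  # the old is_in_the_same_node() answer
local_ranks = [i for i, same in enumerate(status) if same]
remote_ranks = [i for i, same in enumerate(status) if not same]
print(all_on_one_node, local_ranks, remote_ranks)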