[Distributed] Add send and recv helpers (#5719)

commit 5d4d90536f
parent 6c916ac8a8
@@ -8,12 +8,11 @@ import pytest
 import ray
 import torch
 
-from vllm.distributed import (broadcast_tensor_dict,
+from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
 
-from ..utils import (init_test_distributed_environment,
-                     multi_process_tensor_parallel)
+from ..utils import init_test_distributed_environment, multi_process_parallel
 
 
 @ray.remote(num_gpus=1, max_calls=1)
@@ -105,6 +104,68 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
     assert torch.allclose(recv_dict["f"], test_dict["f"])
 
 
+@ray.remote(num_gpus=1, max_calls=1)
+def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
+                                      distributed_init_port: str):
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    test_dict = {
+        # device tensor
+        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
+        # CPU tensor
+        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
+        "c": "test",
+        "d": [1, 2, 3],
+        "e": {
+            "a": 1,
+            "b": 2
+        },
+        # empty tensor
+        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
+    }
+
+    if not get_pp_group().is_first_rank:
+        recv_dict = get_pp_group().recv_tensor_dict()
+
+    if not get_pp_group().is_last_rank:
+        get_pp_group().send_tensor_dict(test_dict)
+
+    if not get_pp_group().is_first_rank:
+        assert len(recv_dict) == len(test_dict)
+        assert torch.allclose(recv_dict["a"], test_dict["a"])
+        assert torch.allclose(recv_dict["b"], test_dict["b"])
+        assert recv_dict["c"] == test_dict["c"]
+        assert recv_dict["d"] == test_dict["d"]
+        assert recv_dict["e"] == test_dict["e"]
+        assert torch.allclose(recv_dict["f"], test_dict["f"])
+
+
+@ray.remote(num_gpus=1, max_calls=1)
+def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
+                          distributed_init_port: str):
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    size = 64
+    test_tensor = torch.arange(64, dtype=torch.float32, device="cuda")
+
+    if not get_pp_group().is_first_rank:
+        recv_tensor = get_pp_group().recv(size, dtype=torch.float32)
+
+    if not get_pp_group().is_last_rank:
+        get_pp_group().send(test_tensor)
+
+    if not get_pp_group().is_first_rank:
+        assert torch.allclose(test_tensor, recv_tensor)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 @pytest.mark.parametrize("tp_size", [2])
@@ -113,4 +174,13 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
     broadcast_tensor_dict_test_worker
 ])
 def test_multi_process_tensor_parallel(tp_size, test_target):
-    multi_process_tensor_parallel(tp_size, 1, test_target)
+    multi_process_parallel(tp_size, 1, test_target)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("pp_size", [2])
+@pytest.mark.parametrize(
+    "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
+def test_multi_process_pipeline_parallel(pp_size, test_target):
+    multi_process_parallel(1, pp_size, test_target)
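Note: the new `test_multi_process_pipeline_parallel` case only exercises a two-rank pipeline, so every worker is either the first or the last rank. As a rough sketch of how the same helpers would chain in a deeper pipeline, a middle rank receives from the previous stage and forwards to the next one; `pipeline_step` and `hidden` below are hypothetical names for illustration only, not part of this commit.

    # Sketch, assuming a worker already initialized with pp_size > 2.
    from vllm.distributed import get_pp_group

    def pipeline_step(hidden):
        # `hidden` is a dict of tensors / plain objects for this stage.
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            # Blocks until the previous pipeline rank calls send_tensor_dict().
            hidden = pp_group.recv_tensor_dict()
        # ... run this stage's computation on `hidden` ...
        if not pp_group.is_last_rank:
            pp_group.send_tensor_dict(hidden)
        return hidden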
@@ -12,8 +12,7 @@ from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
                                              get_tp_group, graph_capture)
 
 from ..utils import (ensure_model_parallel_initialized,
-                     init_test_distributed_environment,
-                     multi_process_tensor_parallel)
+                     init_test_distributed_environment, multi_process_parallel)
 
 random.seed(42)
 test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
@@ -113,4 +112,4 @@ def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
     world_size = tp_size * pipeline_parallel_size
     if world_size > torch.cuda.device_count():
         pytest.skip("Not enough GPUs to run the test.")
-    multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target)
+    multi_process_parallel(tp_size, pipeline_parallel_size, test_target)
@@ -168,9 +168,13 @@ def send_recv_worker_fn():
                         dtype=torch.float32).cuda(pynccl_comm.rank)
     with pynccl_comm.change_state(enable=True):
         if pynccl_comm.rank == 0:
-            pynccl_comm.send(tensor)
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
         else:
-            pynccl_comm.recv(tensor)
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     result = tensor.mean().cpu().item()
     assert result == 1
 
@@ -203,9 +207,13 @@ def multiple_send_recv_worker_fn():
                             device=device)
     with pynccl_comm.change_state(enable=True):
         if torch.distributed.get_rank() in [0, 1]:
-            pynccl_comm.send(tensor)
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
        else:
-            pynccl_comm.recv(tensor)
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     result = tensor.mean().cpu().item()
     if torch.distributed.get_rank() in [0, 2]:
         assert result == 1
@@ -129,7 +129,7 @@ def init_test_distributed_environment(
     ensure_model_parallel_initialized(tp_size, pp_size)
 
 
-def multi_process_tensor_parallel(
+def multi_process_parallel(
     tp_size: int,
     pp_size: int,
     test_target,
@@ -121,10 +121,7 @@ class PyNcclCommunicator:
                                 ncclRedOpTypeEnum.from_torch(op), self.comm,
                                 cudaStream_t(stream.cuda_stream))
 
-    def send(self,
-             tensor: torch.Tensor,
-             dst: Optional[int] = None,
-             stream=None):
+    def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
             return
         assert tensor.device == self.device, (
@@ -132,16 +129,11 @@ class PyNcclCommunicator:
             f"but the input tensor is on {tensor.device}")
         if stream is None:
             stream = self.stream
-        if dst is None:
-            dst = (self.rank + 1) % self.world_size
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
 
-    def recv(self,
-             tensor: torch.Tensor,
-             src: Optional[int] = None,
-             stream=None):
+    def recv(self, tensor: torch.Tensor, src: int, stream=None):
         if self.disabled:
             return
         assert tensor.device == self.device, (
@@ -149,8 +141,6 @@ class PyNcclCommunicator:
             f"but the input tensor is on {tensor.device}")
         if stream is None:
             stream = self.stream
-        if src is None:
-            src = (self.rank - 1) % self.world_size
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
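With the defaults removed, `dst` and `src` become required arguments of `PyNcclCommunicator.send`/`recv`; the ring-neighbor computation `(rank ± 1) % world_size` now lives in the caller. A minimal caller-side sketch, mirroring the updated tests above:

    # Sketch: the caller now picks the peer rank explicitly.
    dst = (pynccl_comm.rank + 1) % pynccl_comm.world_size
    src = (pynccl_comm.rank - 1) % pynccl_comm.world_size
    if pynccl_comm.rank == 0:
        pynccl_comm.send(tensor, dst=dst)
    else:
        pynccl_comm.recv(tensor, src=src)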
@@ -20,6 +20,7 @@ If you only need to use the distributed environment without model/pipeline
 steps.
 """
 import contextlib
+import pickle
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
@@ -28,6 +29,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from unittest.mock import patch
 
 import torch
+import torch.distributed
 from torch.distributed import Backend, ProcessGroup
 
 import vllm.envs as envs
@@ -180,6 +182,16 @@ class GroupCoordinator:
         """Return the global rank of the last process in the group"""
         return self.ranks[-1]
 
+    @property
+    def is_first_rank(self):
+        """Return whether the caller is the first process in the group"""
+        return self.rank == self.first_rank
+
+    @property
+    def is_last_rank(self):
+        """Return whether the caller is the last process in the group"""
+        return self.rank == self.last_rank
+
     @property
     def next_rank(self):
         """Return the global rank of the process that follows the caller"""
@@ -374,6 +386,70 @@ class GroupCoordinator:
                                                 group=self.device_group)
         return obj_list
 
+    def send_object(self, obj: Any, dst: int) -> None:
+        """Send the input object list to the destination rank."""
+        """NOTE: `dst` is the local rank of the destination rank."""
+
+        assert dst < self.world_size, f"Invalid dst rank ({dst})"
+
+        assert dst != self.rank, (
+            "Invalid destination rank. Destination rank is the same "
+            "as the current rank.")
+
+        # Serialize object to tensor and get the size as well
+        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
+
+        size_tensor = torch.tensor([object_tensor.numel()],
+                                   dtype=torch.long,
+                                   device="cpu")
+
+        # Send object size
+
+        torch.distributed.send(size_tensor,
+                               dst=self.ranks[dst],
+                               group=self.cpu_group)
+
+        # Send object
+        torch.distributed.send(object_tensor,
+                               dst=self.ranks[dst],
+                               group=self.cpu_group)
+
+        return None
+
+    def recv_object(self, src: int) -> Any:
+        """Receive the input object list from the source rank."""
+        """NOTE: `src` is the local rank of the source rank."""
+
+        assert src < self.world_size, f"Invalid src rank ({src})"
+
+        assert src != self.rank, (
+            "Invalid source rank. Source rank is the same as the current rank."
+        )
+
+        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
+
+        # Receive object size
+        rank_size = torch.distributed.recv(size_tensor,
+                                           src=src,
+                                           group=self.cpu_group)
+
+        # Tensor to receive serialized objects into.
+        object_tensor = torch.empty(  # type: ignore[call-overload]
+            size_tensor.item(),  # type: ignore[arg-type]
+            dtype=torch.uint8,
+            device="cpu")
+
+        rank_object = torch.distributed.recv(object_tensor,
+                                             src=src,
+                                             group=self.cpu_group)
+
+        assert rank_object == rank_size, (
+            "Received object sender rank does not match the size sender rank.")
+
+        obj = pickle.loads(object_tensor.numpy().tobytes())
+
+        return obj
+
     def broadcast_tensor_dict(
         self,
         tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
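The new `send_object`/`recv_object` pair implements a simple two-message protocol over the CPU group: first a `torch.long` tensor carrying the payload size, then the pickled payload itself, so the receiver knows how large a buffer to allocate. A condensed sketch of the same idea against plain `torch.distributed`; the helper names `send_obj`/`recv_obj` are illustrative, not part of the commit.

    # Sketch, assuming a CPU (gloo) process group is already set up.
    # dst/src here follow torch.distributed conventions (global ranks),
    # whereas the added methods translate group-local ranks via self.ranks.
    import pickle

    import torch
    import torch.distributed as dist

    def send_obj(obj, dst, cpu_group):
        payload = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
        size = torch.tensor([payload.numel()], dtype=torch.long, device="cpu")
        dist.send(size, dst=dst, group=cpu_group)     # 1) announce the size
        dist.send(payload, dst=dst, group=cpu_group)  # 2) send the pickled bytes

    def recv_obj(src, cpu_group):
        size = torch.empty(1, dtype=torch.long, device="cpu")
        dist.recv(size, src=src, group=cpu_group)
        payload = torch.empty(size.item(), dtype=torch.uint8, device="cpu")
        dist.recv(payload, src=src, group=cpu_group)
        return pickle.loads(payload.numpy().tobytes())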
@@ -459,6 +535,88 @@ class GroupCoordinator:
             async_handle.wait()
         return tensor_dict
 
+    def send_tensor_dict(
+        self,
+        tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+        dst: Optional[int] = None
+    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+        """Send the input tensor dictionary.
+        NOTE: `dst` is the local rank of the source rank.
+        """
+        # Bypass the function if we are using only 1 GPU.
+        if not torch.distributed.is_initialized() or self.world_size == 1:
+            return tensor_dict
+
+        group = self.device_group
+        metadata_group = self.cpu_group
+
+        if dst is None:
+            dst = self.next_rank
+        assert dst < self.world_size, f"Invalid dst rank ({dst})"
+
+        metadata_list: List[Tuple[Any, Any]] = []
+        assert isinstance(
+            tensor_dict,
+            dict), f"Expecting a dictionary, got {type(tensor_dict)}"
+        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
+        # `metadata_list` lives in CPU memory.
+        # `send_object_list` has serialization & deserialization,
+        # all happening on CPU. Therefore, we can use the CPU group.
+        self.send_object(metadata_list, dst=dst)
+        for tensor in tensor_list:
+            if tensor.numel() == 0:
+                # Skip sending empty tensors.
+                continue
+            if tensor.is_cpu:
+                # use metadata_group for CPU tensors
+                torch.distributed.send(tensor, dst=dst, group=metadata_group)
+            else:
+                # use group for GPU tensors
+                torch.distributed.send(tensor, dst=dst, group=group)
+        return None
+
+    def recv_tensor_dict(
+        self,
+        src: Optional[int] = None
+    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+        """Recv the input tensor dictionary.
+        NOTE: `src` is the local rank of the source rank.
+        """
+        # Bypass the function if we are using only 1 GPU.
+        if not torch.distributed.is_initialized() or self.world_size == 1:
+            return None
+
+        group = self.device_group
+        metadata_group = self.cpu_group
+
+        if src is None:
+            src = self.prev_rank
+        assert src < self.world_size, f"Invalid src rank ({src})"
+
+        recv_metadata_list = self.recv_object(src=src)
+        tensor_dict = {}
+        for key, value in recv_metadata_list:
+            if isinstance(value, TensorMetadata):
+                tensor = torch.empty(value.size,
+                                     dtype=value.dtype,
+                                     device=value.device)
+                if tensor.numel() == 0:
+                    # Skip broadcasting empty tensors.
+                    tensor_dict[key] = tensor
+                    continue
+                if tensor.is_cpu:
+                    # use metadata_group for CPU tensors
+                    torch.distributed.recv(tensor,
+                                           src=src,
+                                           group=metadata_group)
+                else:
+                    # use group for GPU tensors
+                    torch.distributed.recv(tensor, src=src, group=group)
+                tensor_dict[key] = tensor
+            else:
+                tensor_dict[key] = value
+        return tensor_dict
+
     def barrier(self):
         """Barrier synchronization among the group.
         NOTE: don't use `device_group` here! `barrier` in NCCL is
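Taken together, `send_tensor_dict` ships the dictionary's metadata (keys, dtypes, shapes, devices, and any non-tensor values) through `send_object` on the CPU group, then sends each non-empty tensor separately, using the CPU group for CPU tensors and the device group for GPU tensors; `recv_tensor_dict` rebuilds the dictionary from that metadata. A usage sketch in the spirit of the new test; `payload` and its keys are illustrative only.

    # Sketch: assumes a pipeline-parallel distributed env is initialized
    # and a CUDA device is available.
    import torch

    from vllm.distributed import get_pp_group

    payload = {
        "hidden": torch.randn(4, 8, device="cuda"),     # GPU tensor -> device group
        "lengths": torch.tensor([4, 4], device="cpu"),  # CPU tensor -> CPU group
        "step": 3,                                      # plain object -> metadata
    }
    if not get_pp_group().is_last_rank:
        get_pp_group().send_tensor_dict(payload)      # dst defaults to next_rank
    if not get_pp_group().is_first_rank:
        received = get_pp_group().recv_tensor_dict()  # src defaults to prev_rank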
@@ -468,6 +626,35 @@
         """
         torch.distributed.barrier(group=self.cpu_group)
 
+    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
+        """Sends a tensor to the destination rank in a non-blocking way"""
+        """NOTE: `dst` is the local rank of the destination rank."""
+        if dst is None:
+            dst = self.next_rank
+
+        pynccl_comm = self.pynccl_comm
+        if pynccl_comm is not None and not pynccl_comm.disabled:
+            pynccl_comm.send(tensor, dst)
+        else:
+            torch.distributed.send(tensor, self.ranks[dst], self.device_group)
+
+    def recv(self,
+             size: torch.Size,
+             dtype: torch.dtype,
+             src: Optional[int] = None) -> torch.Tensor:
+        """Receives a tensor from the src rank."""
+        """NOTE: `src` is the local rank of the destination rank."""
+        if src is None:
+            src = self.prev_rank
+
+        tensor = torch.empty(size, dtype=dtype, device=self.device)
+        pynccl_comm = self.pynccl_comm
+        if pynccl_comm is not None and not pynccl_comm.disabled:
+            pynccl_comm.recv(tensor, src)
+        else:
+            torch.distributed.recv(tensor, self.ranks[src], self.device_group)
+        return tensor
+
     def destroy(self):
         if self.device_group is not None:
             torch.distributed.destroy_process_group(self.device_group)
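Finally, `GroupCoordinator.send` and `recv` cover the raw-tensor case: `send` pushes an existing tensor to the next rank (or an explicit `dst`), while `recv` allocates a fresh tensor of the given size and dtype on the group's device, fills it from the previous rank (or an explicit `src`), and returns it, preferring the pynccl communicator and falling back to `torch.distributed`. A short sketch in the style of the new `send_recv_test_worker`:

    # Sketch: the receiver must know the shape and dtype up front,
    # because recv() allocates the buffer itself.
    import torch

    from vllm.distributed import get_pp_group

    t = torch.arange(64, dtype=torch.float32, device="cuda")
    if not get_pp_group().is_last_rank:
        get_pp_group().send(t)  # to next_rank by default
    if not get_pp_group().is_first_rank:
        out = get_pp_group().recv(t.shape, dtype=torch.float32)  # from prev_rank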