[torch.compile] register allreduce operations as custom ops (#8526)
Parent: ee2bceaaa6
Commit: 99aa4eddaf
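Note on the pattern used throughout this change: the allreduce entry points are wrapped with torch.library.custom_op so Dynamo sees them as single opaque nodes, and each op gets a register_fake implementation so the graph can be traced without launching real communication. A minimal, self-contained sketch of that pattern (toy op names, assuming PyTorch >= 2.4; not part of this diff):

import torch


@torch.library.custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Real implementation, executed at runtime.
    return x * factor


@scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Shape/dtype-only implementation used while tracing.
    return torch.empty_like(x)


@torch.compile(fullgraph=True)
def f(x: torch.Tensor) -> torch.Tensor:
    # The custom op appears as one node in the captured graph
    # instead of causing a graph break.
    return torch.ops.demo.scale(x, 2.0) + 1.0


print(f(torch.ones(4)))  # tensor([3., 3., 3., 3.])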
@@ -163,13 +163,6 @@ steps:
   - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
   - python3 offline_inference_encoder_decoder.py
 
-- label: torch compile integration test
-  source_file_dependencies:
-  - vllm/
-  commands:
-  - pytest -v -s ./compile/test_full_graph.py
-  - pytest -v -s ./compile/test_wrapper.py
-
 - label: Prefix Caching Test # 7min
   #mirror_hardwares: [amd]
   source_file_dependencies:
@@ -348,7 +341,10 @@ steps:
   - vllm/executor/
   - vllm/model_executor/models/
   - tests/distributed/
+  - vllm/compilation
   commands:
+  - pytest -v -s ./compile/test_full_graph.py
+  - pytest -v -s ./compile/test_wrapper.py
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
   - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus
   # Avoid importing model tests that cause CUDA reinitialization error
@@ -55,18 +55,6 @@ bool _is_weak_contiguous(torch::Tensor& t) {
                             t.numel() * t.element_size());
 }
 
-bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
-                      bool full_nvlink) {
-  auto inp_size = inp.numel() * inp.element_size();
-  // custom allreduce requires input byte size to be multiples of 16
-  if (inp_size % 16 != 0) return false;
-  if (!_is_weak_contiguous(inp)) return false;
-  if (world_size == 2 || full_nvlink) return inp_size <= max_size;
-  // for 4 or more non NVLink-capable GPUs, custom allreduce provides little
-  // performance improvement over NCCL.
-  return false;
-}
-
 void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                  cudaStream_t stream) {
   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
@@ -241,8 +241,6 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                       const std::vector<std::string>& handles,
                       const std::vector<int64_t>& offsets, int64_t rank,
                       bool full_nvlink);
-bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
-                      bool full_nvlink);
 void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
 void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                       torch::Tensor& out);
@@ -411,11 +411,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
       "bool full_nvlink) -> int");
   custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
-
-  custom_ar.def(
-      "should_custom_ar(Tensor inp, int max_size, int world_size, "
-      "bool full_nvlink) -> bool");
-  custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);
 
   custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
   custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
 
tests/compile/__init__.py (new empty file, 0 lines)
@@ -2,9 +2,20 @@ import os
 
 import pytest
 
+from vllm.utils import cuda_device_count_stateless
+
+from ..utils import fork_new_process_for_each_test
+
 
 @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
-def test_full_graph(model):
+@pytest.mark.parametrize("tp_size", [1, 2])
+@fork_new_process_for_each_test
+def test_full_graph(model, tp_size):
+
+    # Skip the test if there are not enough CUDA devices.
+    if cuda_device_count_stateless() < tp_size:
+        pytest.skip("Not enough CUDA devices for the test.")
+
     # make sure these models can be captured in full graph mode
     if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
         os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
@@ -17,7 +28,7 @@ def test_full_graph(model):
         "The future of AI is",
     ]
     sampling_params = SamplingParams(temperature=0)
-    llm = LLM(model=model, enforce_eager=True)
+    llm = LLM(model=model, enforce_eager=True, tensor_parallel_size=tp_size)
 
     outputs = llm.generate(prompts, sampling_params)
 
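Taken together, the two test hunks above make the full-graph test cover tensor parallelism. Roughly the following runs in a fresh process per parametrization; this is a paraphrase of the updated test (same model and env hook as in the diff), not new functionality:

import os

os.environ.setdefault("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "1")

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B",
          enforce_eager=True,
          tensor_parallel_size=2)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0))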
@@ -870,12 +870,6 @@ def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                                                  offsets, rank, full_nvlink)
 
 
-def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
-                     full_nvlink: bool) -> bool:
-    return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size,
-                                                   full_nvlink)
-
-
 def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
     torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)
 
@@ -33,6 +33,12 @@ def _can_p2p(rank: int, world_size: int) -> bool:
     return True
 
 
+def is_weak_contiguous(inp: torch.Tensor):
+    return inp.is_contiguous() or (inp.storage().nbytes() -
+                                   inp.storage_offset() * inp.element_size()
+                                   == inp.numel() * inp.element_size())
+
+
 class CustomAllreduce:
 
     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
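A quick illustration of what the new is_weak_contiguous helper accepts: a tensor qualifies if it is contiguous, or if its view still spans the entire remainder of its storage. Toy tensors only; the helper is copied from the hunk above so the snippet is self-contained:

import torch


def is_weak_contiguous(inp: torch.Tensor):
    # Same definition as in the hunk above.
    return inp.is_contiguous() or (inp.storage().nbytes() -
                                   inp.storage_offset() * inp.element_size()
                                   == inp.numel() * inp.element_size())


a = torch.empty(4, 8)
b = a.t()      # not contiguous, but spans all of its storage
c = a[:, :4]   # spans only part of its storage

print(b.is_contiguous(), is_weak_contiguous(b))  # False True
print(c.is_contiguous(), is_weak_contiguous(c))  # False False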
@@ -224,8 +230,19 @@ class CustomAllreduce:
         ops.register_graph_buffers(self._ptr, handles, offsets)
 
     def should_custom_ar(self, inp: torch.Tensor):
-        return ops.should_custom_ar(inp, self.max_size, self.world_size,
-                                    self.full_nvlink)
+        if self.disabled:
+            return False
+        inp_size = inp.numel() * inp.element_size()
+        # custom allreduce requires input byte size to be multiples of 16
+        if inp_size % 16 != 0:
+            return False
+        if not is_weak_contiguous(inp):
+            return False
+        # for 4 or more non NVLink-capable GPUs, custom allreduce provides
+        # little performance improvement over NCCL.
+        if self.world_size == 2 or self.full_nvlink:
+            return inp_size < self.max_size
+        return False
 
     # all reduce, assuming inp tensor is IPC registered with register_buffer,
     # or, in the context of cuda graphs, register_graph_buffers
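The Python port above makes the custom-allreduce eligibility rule easy to check outside a running instance. A standalone restatement (the 8 MiB cap is only an illustrative max_size, and the weak-contiguity branch is folded into plain contiguity for brevity):

import torch


def eligible(inp: torch.Tensor, max_size: int, world_size: int,
             full_nvlink: bool) -> bool:
    inp_size = inp.numel() * inp.element_size()
    if inp_size % 16 != 0:        # byte size must be a multiple of 16
        return False
    if not inp.is_contiguous():   # simplified: the real check allows weak contiguity
        return False
    if world_size == 2 or full_nvlink:
        return inp_size < max_size
    return False


# A (4096,) fp16 tensor is 8192 bytes: a multiple of 16 and under an 8 MiB cap.
print(eligible(torch.empty(4096, dtype=torch.float16),
               max_size=8 * 1024 * 1024, world_size=2, full_nvlink=True))  # True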
@@ -21,11 +21,12 @@ If you only need to use the distributed environment without model/pipeline
 """
 import contextlib
 import pickle
+import weakref
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from multiprocessing import shared_memory
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from unittest.mock import patch
 
 import torch
@@ -69,6 +70,58 @@ def _split_tensor_dict(
     return metadata_list, tensor_list
 
 
+_group_name_counter: Dict[str, int] = {}
+
+
+def _get_unique_name(name: str) -> str:
+    """Get a unique name for the group.
+    Example:
+    _get_unique_name("tp") -> "tp:0"
+    _get_unique_name("tp") -> "tp:1"
+    """
+    if name not in _group_name_counter:
+        _group_name_counter[name] = 0
+    newname = f"{name}:{_group_name_counter[name]}"
+    _group_name_counter[name] += 1
+    return newname
+
+
+_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {}
+
+
+def _register_group(group: "GroupCoordinator") -> None:
+    # looks like Python 3.8 does not understand `ReferenceType`
+    _groups[group.unique_name] = weakref.ref(group)  # type: ignore
+
+
+@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"])
+def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
+    assert group_name in _groups, f"Group {group_name} is not found."
+    group = _groups[group_name]()
+    if group is None:
+        raise ValueError(f"Group {group_name} is destroyed.")
+    group._all_reduce(tensor)
+
+
+@inplace_all_reduce.register_fake
+def _(tensor: torch.Tensor, group_name: str) -> None:
+    return
+
+
+@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
+def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    assert group_name in _groups, f"Group {group_name} is not found."
+    group = _groups[group_name]()
+    if group is None:
+        raise ValueError(f"Group {group_name} is destroyed.")
+    return group._all_reduce(tensor)
+
+
+@outplace_all_reduce.register_fake
+def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    return torch.empty_like(tensor)
+
+
 class GroupCoordinator:
     """
     PyTorch ProcessGroup wrapper for a group of processes.
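Why the weakref registry in the hunk above: a custom op can only take tensors and plain values such as strings, so the GroupCoordinator object itself cannot be an argument; the op receives the group's unique name and resolves it at call time. A toy version of that lookup pattern (illustrative names only, not vLLM API):

import weakref
from typing import Callable, Dict

_registry: Dict[str, Callable[[], object]] = {}


class Worker:
    def __init__(self, name: str) -> None:
        self.name = name
        # Store a weak reference so the registry never keeps the object alive.
        _registry[name] = weakref.ref(self)

    def bump(self, x: int) -> int:
        return x + 1


def call_by_name(name: str, x: int) -> int:
    ref = _registry[name]()
    if ref is None:
        raise ValueError(f"{name} was destroyed")
    return ref.bump(x)


w = Worker("tp:0")
print(call_by_name("tp:0", 41))  # 42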
@@ -111,7 +164,11 @@ class GroupCoordinator:
         use_custom_allreduce: bool,
         use_tpu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
+        group_name: Optional[str] = None,
     ):
+        group_name = group_name or "anonymous"
+        self.unique_name = _get_unique_name(group_name)
+        _register_group(self)
+
         self.rank = torch.distributed.get_rank()
         self.local_rank = local_rank
@@ -149,28 +206,24 @@ class GroupCoordinator:
         from vllm.distributed.device_communicators.pynccl import (
             PyNcclCommunicator)
 
-        self.pynccl_comm: Optional[PyNcclCommunicator]
+        self.pynccl_comm: Optional[PyNcclCommunicator] = None
         if use_pynccl and self.world_size > 1:
             self.pynccl_comm = PyNcclCommunicator(
                 group=self.cpu_group,
                 device=self.device,
             )
-        else:
-            self.pynccl_comm = None
 
-        self.ca_comm: Optional[CustomAllreduce]
+        self.ca_comm: Optional[CustomAllreduce] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
             self.ca_comm = CustomAllreduce(
                 group=self.cpu_group,
                 device=self.device,
             )
-        else:
-            self.ca_comm = None
 
         from vllm.distributed.device_communicators.tpu_communicator import (
             TpuCommunicator)
-        self.tpu_communicator: Optional[TpuCommunicator]
+        self.tpu_communicator: Optional[TpuCommunicator] = None
         if use_tpu_communicator and self.world_size > 1:
             self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
 
@@ -264,16 +317,46 @@ class GroupCoordinator:
 
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         """
+        User-facing all-reduce function before we actually call the
+        all-reduce operation.
+
+        We need this because Dynamo does not support passing an arbitrary
+        object (`self` in this case) to a custom op. We need to pass the
+        group name as a string, and then look up the group coordinator from
+        the group name, dispatch the all-reduce operation to the group
+        coordinator.
+
+        In addition, PyTorch custom ops do not support mutation or returning
+        a new tensor in the same op. So we need to figure out if the op is
+        in-place or out-of-place ahead of time.
+        """
+        # Bypass the function if we are using only 1 GPU.
+        if self.world_size == 1:
+            return input_
+
+        if self.tpu_communicator is not None and \
+            not self.tpu_communicator.disabled:
+            # TPU handles Dynamo with its own logic.
+            return self._all_reduce(input_)
+
+        if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
+            return torch.ops.vllm.outplace_all_reduce(
+                input_, group_name=self.unique_name)
+        else:
+            torch.ops.vllm.inplace_all_reduce(input_,
+                                              group_name=self.unique_name)
+            return input_
+
+    def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+        """
+        The actual all-reduce implementation.
+
         NOTE: This operation will be applied in-place or out-of-place.
         Always assume this function modifies its input, but use the return
         value as the output.
         """
         ca_comm = self.ca_comm
-
-        # Bypass the function if we are using only 1 GPU.
-        if self.world_size == 1:
-            return input_
-
         # For TPUs, use TPU communicator.
         tpu_comm = self.tpu_communicator
         if tpu_comm is not None and not tpu_comm.disabled:
|
|||||||
use_pynccl=False,
|
use_pynccl=False,
|
||||||
use_custom_allreduce=False,
|
use_custom_allreduce=False,
|
||||||
use_tpu_communicator=False,
|
use_tpu_communicator=False,
|
||||||
|
group_name="world",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -767,6 +851,7 @@ def init_model_parallel_group(
     backend: str,
     use_custom_allreduce: Optional[bool] = None,
     use_message_queue_broadcaster: bool = False,
+    group_name: Optional[str] = None,
 ) -> GroupCoordinator:
     if use_custom_allreduce is None:
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
@@ -778,6 +863,7 @@ def init_model_parallel_group(
         use_custom_allreduce=use_custom_allreduce,
         use_tpu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
+        group_name=group_name,
     )
 
 
@@ -931,7 +1017,8 @@ def initialize_model_parallel(
     _TP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
-                                    use_message_queue_broadcaster=True)
+                                    use_message_queue_broadcaster=True,
+                                    group_name="tp")
 
     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = (world_size //
@@ -947,7 +1034,8 @@ def initialize_model_parallel(
     _PP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
-                                    use_custom_allreduce=False)
+                                    use_custom_allreduce=False,
+                                    group_name="pp")
 
 
 def ensure_model_parallel_initialized(