[torch.compile] register allreduce operations as custom ops (#8526)
parent ee2bceaaa6 · commit 99aa4eddaf
@@ -163,13 +163,6 @@ steps:
  - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
  - python3 offline_inference_encoder_decoder.py

- label: torch compile integration test
  source_file_dependencies:
  - vllm/
  commands:
  - pytest -v -s ./compile/test_full_graph.py
  - pytest -v -s ./compile/test_wrapper.py

- label: Prefix Caching Test # 7min
  #mirror_hardwares: [amd]
  source_file_dependencies:
@@ -348,7 +341,10 @@ steps:
  - vllm/executor/
  - vllm/model_executor/models/
  - tests/distributed/
  - vllm/compilation
  commands:
  - pytest -v -s ./compile/test_full_graph.py
  - pytest -v -s ./compile/test_wrapper.py
  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
  - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus
  # Avoid importing model tests that cause CUDA reinitialization error
@@ -55,18 +55,6 @@ bool _is_weak_contiguous(torch::Tensor& t) {
                          t.numel() * t.element_size());
}

bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
                      bool full_nvlink) {
  auto inp_size = inp.numel() * inp.element_size();
  // custom allreduce requires input byte size to be multiples of 16
  if (inp_size % 16 != 0) return false;
  if (!_is_weak_contiguous(inp)) return false;
  if (world_size == 2 || full_nvlink) return inp_size <= max_size;
  // for 4 or more non NVLink-capable GPUs, custom allreduce provides little
  // performance improvement over NCCL.
  return false;
}

void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                 cudaStream_t stream) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
@@ -241,8 +241,6 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                      const std::vector<std::string>& handles,
                      const std::vector<int64_t>& offsets, int64_t rank,
                      bool full_nvlink);
bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
                      bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                      torch::Tensor& out);
@@ -411,11 +411,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
      "bool full_nvlink) -> int");
  custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);

  custom_ar.def(
      "should_custom_ar(Tensor inp, int max_size, int world_size, "
      "bool full_nvlink) -> bool");
  custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);

  custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
  custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
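Aside (not part of this diff): the strings passed to custom_ar.def use TorchScript schema syntax, where "Tensor!" marks an argument the op is allowed to mutate, and custom_ar.impl binds the kernel for a dispatch key. A minimal Python-side sketch of the same def/impl split, using the made-up namespace "mylib":

import torch

# Declare the schema, then attach an implementation for a dispatch key.
lib = torch.library.Library("mylib", "DEF")
lib.define("copy_to(Tensor src, Tensor! out) -> ()")


def copy_to_impl(src: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(src)  # mutates `out`, mirroring the Tensor! annotation


lib.impl("copy_to", copy_to_impl, "CompositeExplicitAutograd")

out = torch.zeros(3)
torch.ops.mylib.copy_to(torch.arange(3.0), out)
print(out)  # tensor([0., 1., 2.])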
tests/compile/__init__.py (new file, 0 lines)
@@ -2,9 +2,20 @@ import os

import pytest

from vllm.utils import cuda_device_count_stateless

from ..utils import fork_new_process_for_each_test


@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_full_graph(model):
@pytest.mark.parametrize("tp_size", [1, 2])
@fork_new_process_for_each_test
def test_full_graph(model, tp_size):

    # Skip the test if there are not enough CUDA devices.
    if cuda_device_count_stateless() < tp_size:
        pytest.skip("Not enough CUDA devices for the test.")

    # make sure these models can be captured in full graph mode
    if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
        os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
@@ -17,7 +28,7 @@ def test_full_graph(model):
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0)
    llm = LLM(model=model, enforce_eager=True)
    llm = LLM(model=model, enforce_eager=True, tensor_parallel_size=tp_size)

    outputs = llm.generate(prompts, sampling_params)
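For orientation (not part of this PR's diff): the fork_new_process_for_each_test helper imported above from ..utils gives each parametrized case its own process, so CUDA is never initialized in the long-lived pytest process. A rough sketch of that kind of decorator, assuming a POSIX system; this is an illustration, not vLLM's actual implementation:

import functools
import os


def fork_per_test(fn):
    """Run the wrapped test body in a freshly forked child process."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        pid = os.fork()
        if pid == 0:
            # Child: run the test and report success/failure via exit code.
            try:
                fn(*args, **kwargs)
                os._exit(0)
            except BaseException:
                os._exit(1)
        # Parent: wait for the child and surface its result to pytest.
        _, status = os.waitpid(pid, 0)
        assert os.WEXITSTATUS(status) == 0, f"{fn.__name__} failed in child"

    return wrapper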
@@ -870,12 +870,6 @@ def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                                             offsets, rank, full_nvlink)


def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
                     full_nvlink: bool) -> bool:
    return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size,
                                                   full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)
@@ -33,6 +33,12 @@ def _can_p2p(rank: int, world_size: int) -> bool:
    return True


def is_weak_contiguous(inp: torch.Tensor):
    return inp.is_contiguous() or (inp.storage().nbytes() -
                                   inp.storage_offset() * inp.element_size()
                                   == inp.numel() * inp.element_size())


class CustomAllreduce:

    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
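A quick illustration of the weak-contiguity rule above (standalone sketch, not from the PR): a transposed view fails is_contiguous() but still covers its entire storage, while a column slice leaves unused bytes in the storage and is rejected.

import torch


def is_weak_contiguous(inp: torch.Tensor):
    return inp.is_contiguous() or (inp.storage().nbytes() -
                                   inp.storage_offset() * inp.element_size()
                                   == inp.numel() * inp.element_size())


x = torch.arange(16, dtype=torch.float32).reshape(4, 4)
print(is_weak_contiguous(x))         # True: plainly contiguous
print(is_weak_contiguous(x.t()))     # True: non-contiguous view, but spans the whole 64-byte storage
print(is_weak_contiguous(x[:, 1:]))  # False: 60 usable bytes from the offset vs 48 bytes needed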
@@ -224,8 +230,19 @@ class CustomAllreduce:
        ops.register_graph_buffers(self._ptr, handles, offsets)

    def should_custom_ar(self, inp: torch.Tensor):
        return ops.should_custom_ar(inp, self.max_size, self.world_size,
                                    self.full_nvlink)
        if self.disabled:
            return False
        inp_size = inp.numel() * inp.element_size()
        # custom allreduce requires input byte size to be multiples of 16
        if inp_size % 16 != 0:
            return False
        if not is_weak_contiguous(inp):
            return False
        # for 4 or more non NVLink-capable GPUs, custom allreduce provides
        # little performance improvement over NCCL.
        if self.world_size == 2 or self.full_nvlink:
            return inp_size < self.max_size
        return False

    # all reduce, assuming inp tensor is IPC registered with register_buffer,
    # or, in the context of cuda graphs, register_graph_buffers
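Worked numbers for the checks above (illustrative only; max_size comes from the CustomAllreduce instance, and the 8 MiB value below is just a stand-in for that threshold):

import torch

max_size = 8 * 1024 * 1024  # placeholder threshold in bytes

a = torch.empty(1024, 1024, dtype=torch.float16)  # 2 MiB, a multiple of 16
b = torch.empty(10, dtype=torch.float16)          # 20 bytes, not a multiple of 16

print(a.numel() * a.element_size())  # 2097152 -> eligible when world_size == 2 or full NVLink, since it is below max_size
print(b.numel() * b.element_size())  # 20 -> rejected by the % 16 check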
@@ -21,11 +21,12 @@ If you only need to use the distributed environment without model/pipeline
"""
import contextlib
import pickle
import weakref
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

import torch
@@ -69,6 +70,58 @@ def _split_tensor_dict(
    return metadata_list, tensor_list


_group_name_counter: Dict[str, int] = {}


def _get_unique_name(name: str) -> str:
    """Get a unique name for the group.
    Example:
    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    if name not in _group_name_counter:
        _group_name_counter[name] = 0
    newname = f"{name}:{_group_name_counter[name]}"
    _group_name_counter[name] += 1
    return newname


_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {}


def _register_group(group: "GroupCoordinator") -> None:
    # looks like Python 3.8 does not understand `ReferenceType`
    _groups[group.unique_name] = weakref.ref(group)  # type: ignore


@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"])
def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
    assert group_name in _groups, f"Group {group_name} is not found."
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    group._all_reduce(tensor)


@inplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> None:
    return


@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    assert group_name in _groups, f"Group {group_name} is not found."
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group._all_reduce(tensor)


@outplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    return torch.empty_like(tensor)


class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
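The core pattern here: the compiled graph only ever sees a string group name, and the op implementation recovers the GroupCoordinator from a module-level registry of weak references, because Dynamo cannot pass an arbitrary Python object such as self into a custom op. A self-contained sketch of the same pattern with made-up names (demo::scale_by, Scaler), assuming PyTorch >= 2.4 for torch.library.custom_op and register_fake:

import weakref
from typing import Callable, Dict

import torch

# Registry of weakly referenced objects, keyed by a unique string name.
_registry: Dict[str, Callable[[], "Scaler"]] = {}


class Scaler:

    def __init__(self, name: str, factor: float):
        self.name = name
        self.factor = factor
        _registry[name] = weakref.ref(self)

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.factor


@torch.library.custom_op("demo::scale_by", mutates_args=())
def scale_by(x: torch.Tensor, name: str) -> torch.Tensor:
    # The op receives only the string name and looks the object up at runtime.
    obj = _registry[name]()
    assert obj is not None, f"{name} was destroyed"
    return obj.apply(x)


@scale_by.register_fake
def _(x: torch.Tensor, name: str) -> torch.Tensor:
    # Shape/dtype-only implementation used during tracing.
    return torch.empty_like(x)


s = Scaler("tp:0", 2.0)
fn = torch.compile(lambda x: torch.ops.demo.scale_by(x, "tp:0"), fullgraph=True)
print(fn(torch.ones(4)))  # tensor([2., 2., 2., 2.])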
@@ -111,7 +164,11 @@ class GroupCoordinator:
        use_custom_allreduce: bool,
        use_tpu_communicator: bool,
        use_message_queue_broadcaster: bool = False,
        group_name: Optional[str] = None,
    ):
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
@@ -149,28 +206,24 @@ class GroupCoordinator:
        from vllm.distributed.device_communicators.pynccl import (
            PyNcclCommunicator)

        self.pynccl_comm: Optional[PyNcclCommunicator]
        self.pynccl_comm: Optional[PyNcclCommunicator] = None
        if use_pynccl and self.world_size > 1:
            self.pynccl_comm = PyNcclCommunicator(
                group=self.cpu_group,
                device=self.device,
            )
        else:
            self.pynccl_comm = None

        self.ca_comm: Optional[CustomAllreduce]
        self.ca_comm: Optional[CustomAllreduce] = None
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            self.ca_comm = CustomAllreduce(
                group=self.cpu_group,
                device=self.device,
            )
        else:
            self.ca_comm = None

        from vllm.distributed.device_communicators.tpu_communicator import (
            TpuCommunicator)
        self.tpu_communicator: Optional[TpuCommunicator]
        self.tpu_communicator: Optional[TpuCommunicator] = None
        if use_tpu_communicator and self.world_size > 1:
            self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
@@ -264,16 +317,46 @@ class GroupCoordinator:

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. Instead, we pass the
        group name as a string, look up the group coordinator from that
        name, and dispatch the all-reduce operation to the coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
        a new tensor in the same op. So we need to figure out if the op is
        in-place or out-of-place ahead of time.
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

        if self.tpu_communicator is not None and \
            not self.tpu_communicator.disabled:
            # TPU handles Dynamo with its own logic.
            return self._all_reduce(input_)

        if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
            return torch.ops.vllm.outplace_all_reduce(
                input_, group_name=self.unique_name)
        else:
            torch.ops.vllm.inplace_all_reduce(input_,
                                              group_name=self.unique_name)
            return input_

    def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
        The actual all-reduce implementation.

        NOTE: This operation will be applied in-place or out-of-place.
        Always assume this function modifies its input, but use the return
        value as the output.
        """
        ca_comm = self.ca_comm

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

        # For TPUs, use TPU communicator.
        tpu_comm = self.tpu_communicator
        if tpu_comm is not None and not tpu_comm.disabled:
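To make the in-place/out-of-place split above concrete, here is a standalone sketch (made-up demo::* ops, assuming PyTorch >= 2.4, not vLLM code): one custom op mutates its argument and returns nothing, the other returns a fresh tensor, and torch.compile captures both without graph breaks.

import torch


@torch.library.custom_op("demo::double_inplace", mutates_args=["x"])
def double_inplace(x: torch.Tensor) -> None:
    x.mul_(2)  # mutation declared via mutates_args, nothing returned


@double_inplace.register_fake
def _(x: torch.Tensor) -> None:
    return


@torch.library.custom_op("demo::double_outplace", mutates_args=())
def double_outplace(x: torch.Tensor) -> torch.Tensor:
    return x * 2  # out-of-place: a new tensor is returned


@double_outplace.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)


@torch.compile(fullgraph=True)
def f(x):
    y = torch.ops.demo.double_outplace(x)  # captured as an out-of-place op
    torch.ops.demo.double_inplace(y)       # captured and functionalized by the compiler
    return y


print(f(torch.ones(2)))  # tensor([4., 4.])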
@@ -758,6 +841,7 @@ def init_world_group(ranks: List[int], local_rank: int,
        use_pynccl=False,
        use_custom_allreduce=False,
        use_tpu_communicator=False,
        group_name="world",
    )
@@ -767,6 +851,7 @@ def init_model_parallel_group(
    backend: str,
    use_custom_allreduce: Optional[bool] = None,
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
) -> GroupCoordinator:
    if use_custom_allreduce is None:
        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
@@ -778,6 +863,7 @@ def init_model_parallel_group(
        use_custom_allreduce=use_custom_allreduce,
        use_tpu_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        group_name=group_name,
    )
@@ -931,7 +1017,8 @@ def initialize_model_parallel(
    _TP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_message_queue_broadcaster=True)
                                    use_message_queue_broadcaster=True,
                                    group_name="tp")

    # Build the pipeline model-parallel groups.
    num_pipeline_model_parallel_groups: int = (world_size //
@@ -947,7 +1034,8 @@ def initialize_model_parallel(
    _PP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_custom_allreduce=False)
                                    use_custom_allreduce=False,
                                    group_name="pp")


def ensure_model_parallel_initialized(