[torch.compile] register allreduce operations as custom ops (#8526)

youkaichao 2024-09-16 22:57:57 -07:00 committed by GitHub
parent ee2bceaaa6
commit 99aa4eddaf
9 changed files with 137 additions and 50 deletions

View File

@@ -163,13 +163,6 @@ steps:
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference_encoder_decoder.py
- label: torch compile integration test
source_file_dependencies:
- vllm/
commands:
- pytest -v -s ./compile/test_full_graph.py
- pytest -v -s ./compile/test_wrapper.py
- label: Prefix Caching Test # 7min
#mirror_hardwares: [amd]
source_file_dependencies:
@@ -348,7 +341,10 @@ steps:
- vllm/executor/
- vllm/model_executor/models/
- tests/distributed/
- vllm/compilation
commands:
- pytest -v -s ./compile/test_full_graph.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus
# Avoid importing model tests that cause CUDA reinitialization error

View File

@@ -55,18 +55,6 @@ bool _is_weak_contiguous(torch::Tensor& t) {
t.numel() * t.element_size());
}
bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
bool full_nvlink) {
auto inp_size = inp.numel() * inp.element_size();
// custom allreduce requires input byte size to be multiples of 16
if (inp_size % 16 != 0) return false;
if (!_is_weak_contiguous(inp)) return false;
if (world_size == 2 || full_nvlink) return inp_size <= max_size;
// for 4 or more non NVLink-capable GPUs, custom allreduce provides little
// performance improvement over NCCL.
return false;
}
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
cudaStream_t stream) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);

View File

@@ -241,8 +241,6 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int64_t rank,
bool full_nvlink);
bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out);

View File

@@ -411,11 +411,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
"bool full_nvlink) -> int");
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
custom_ar.def(
"should_custom_ar(Tensor inp, int max_size, int world_size, "
"bool full_nvlink) -> bool");
custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);
custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"); custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg); custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);

View File

View File

@@ -2,9 +2,20 @@ import os
import pytest
from vllm.utils import cuda_device_count_stateless
from ..utils import fork_new_process_for_each_test
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_full_graph(model): @pytest.mark.parametrize("tp_size", [1, 2])
@fork_new_process_for_each_test
def test_full_graph(model, tp_size):
# Skip the test if there are not enough CUDA devices.
if cuda_device_count_stateless() < tp_size:
pytest.skip("Not enough CUDA devices for the test.")
# make sure these models can be captured in full graph mode
if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
@@ -17,7 +28,7 @@ def test_full_graph(model):
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model=model, enforce_eager=True)
llm = LLM(model=model, enforce_eager=True, tensor_parallel_size=tp_size)
outputs = llm.generate(prompts, sampling_params)
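
The test is now parametrized over tp_size and wrapped in fork_new_process_for_each_test (imported from tests/utils), presumably so each tensor-parallel run initializes CUDA in a fresh child process and avoids the CUDA re-initialization error mentioned in the CI config above. A rough, hypothetical sketch of what such a decorator can look like; this is not vLLM's actual implementation:

import functools
import os

def fork_new_process_for_each_test(f):
    """Run the wrapped test in a forked child process (POSIX-only sketch)."""
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        pid = os.fork()
        if pid == 0:
            # Child: run the test body and report success/failure via exit code.
            try:
                f(*args, **kwargs)
                os._exit(0)
            except BaseException:
                import traceback
                traceback.print_exc()
                os._exit(1)
        # Parent: wait for the child and fail if it did not exit cleanly.
        _, status = os.waitpid(pid, 0)
        assert os.waitstatus_to_exitcode(status) == 0, (
            f"{f.__name__} failed in the child process")
    return wrapper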

View File

@@ -870,12 +870,6 @@ def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
offsets, rank, full_nvlink)
def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
full_nvlink: bool) -> bool:
return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size,
full_nvlink)
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)

View File

@@ -33,6 +33,12 @@ def _can_p2p(rank: int, world_size: int) -> bool:
return True
def is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (inp.storage().nbytes() -
inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size())
class CustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
@@ -224,8 +230,19 @@ class CustomAllreduce:
ops.register_graph_buffers(self._ptr, handles, offsets)
def should_custom_ar(self, inp: torch.Tensor):
return ops.should_custom_ar(inp, self.max_size, self.world_size,
self.full_nvlink)
if self.disabled:
return False
inp_size = inp.numel() * inp.element_size()
# custom allreduce requires input byte size to be multiples of 16
if inp_size % 16 != 0:
return False
if not is_weak_contiguous(inp):
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
if self.world_size == 2 or self.full_nvlink:
return inp_size < self.max_size
return False
# all reduce, assuming inp tensor is IPC registered with register_buffer,
# or, in the context of cuda graphs, register_graph_buffers
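
The hunk above moves the eligibility check from C++ (the removed should_custom_ar in csrc) into Python. Two details visible in the diff itself: the Python version adds an early self.disabled guard, and it compares inp_size < self.max_size where the removed C++ code used <=. The "weak contiguity" condition is the least obvious part; a small self-contained illustration using the same helper as the diff (the tensor names below are made up):

import torch

def is_weak_contiguous(inp: torch.Tensor):
    # Same check as in the diff: the tensor is either contiguous, or it
    # exactly covers the remainder of its storage from its offset onward.
    return inp.is_contiguous() or (inp.storage().nbytes() -
                                   inp.storage_offset() * inp.element_size()
                                   == inp.numel() * inp.element_size())

base = torch.empty(4, 4)
print(is_weak_contiguous(base))          # True: plain contiguous tensor
print(is_weak_contiguous(base.t()))      # True: non-contiguous view, but it spans the whole storage
print(is_weak_contiguous(base[:, 1:]))   # False: a strided slice that skips bytes of the storage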

View File

@@ -21,11 +21,12 @@ If you only need to use the distributed environment without model/pipeline
"""
import contextlib
import pickle
import weakref
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch
import torch
@@ -69,6 +70,58 @@ def _split_tensor_dict(
return metadata_list, tensor_list
_group_name_counter: Dict[str, int] = {}
def _get_unique_name(name: str) -> str:
"""Get a unique name for the group.
Example:
_get_unique_name("tp") -> "tp:0"
_get_unique_name("tp") -> "tp:1"
"""
if name not in _group_name_counter:
_group_name_counter[name] = 0
newname = f"{name}:{_group_name_counter[name]}"
_group_name_counter[name] += 1
return newname
_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {}
def _register_group(group: "GroupCoordinator") -> None:
# looks like Python 3.8 does not understand `ReferenceType`
_groups[group.unique_name] = weakref.ref(group) # type: ignore
@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"])
def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
group._all_reduce(tensor)
@inplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> None:
return
@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group._all_reduce(tensor)
@outplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return torch.empty_like(tensor)
class GroupCoordinator:
"""
PyTorch ProcessGroup wrapper for a group of processes.
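
The hunk above is the core of the change: every GroupCoordinator registers a weak reference to itself under a unique string name ("tp:0", "pp:0", ...), and the module-level vllm::inplace_all_reduce / vllm::outplace_all_reduce custom ops look the coordinator back up from that string, because Dynamo cannot pass an arbitrary Python object such as `self` into a custom op. A minimal, self-contained sketch of the same pattern; the demo namespace, Worker class, and op names below are made-up stand-ins for illustration:

import weakref
from typing import Callable, Dict

import torch

# Module-level registry of weak references, keyed by a unique string name,
# mirroring `_groups` in the PR.
_registry: Dict[str, Callable[[], "Worker"]] = {}

class Worker:
    def __init__(self, name: str, scale: float):
        self.unique_name = name
        self.scale = scale
        _registry[name] = weakref.ref(self)

    def _scale_inplace(self, t: torch.Tensor) -> None:
        t.mul_(self.scale)

@torch.library.custom_op("demo::inplace_scale", mutates_args=["tensor"])
def inplace_scale(tensor: torch.Tensor, worker_name: str) -> None:
    # Dynamo cannot pass `self` into a custom op, so the worker is looked
    # up by its registered string name, mirroring the group lookup above.
    worker = _registry[worker_name]()
    if worker is None:
        raise ValueError(f"Worker {worker_name} is destroyed.")
    worker._scale_inplace(tensor)

@inplace_scale.register_fake
def _(tensor: torch.Tensor, worker_name: str) -> None:
    return

w = Worker("w:0", 2.0)

@torch.compile(fullgraph=True)
def fn(x: torch.Tensor) -> torch.Tensor:
    # Only a tensor and a string cross the op boundary, so the call can be
    # captured in a full Dynamo graph.
    torch.ops.demo.inplace_scale(x, worker_name="w:0")
    return x + 1

print(fn(torch.ones(4)))  # tensor([3., 3., 3., 3.])

Using a weak reference keeps the module-level registry from extending the coordinator's lifetime: if the object has been garbage-collected, the op raises instead of silently operating on a destroyed group, which matches the error handling in the diff.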
@@ -111,7 +164,11 @@ class GroupCoordinator:
use_custom_allreduce: bool,
use_tpu_communicator: bool,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
):
group_name = group_name or "anonymous"
self.unique_name = _get_unique_name(group_name)
_register_group(self)
self.rank = torch.distributed.get_rank()
self.local_rank = local_rank
@@ -149,28 +206,24 @@ class GroupCoordinator:
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator]
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
)
else:
self.pynccl_comm = None
self.ca_comm: Optional[CustomAllreduce]
self.ca_comm: Optional[CustomAllreduce] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
device=self.device,
)
else:
self.ca_comm = None
from vllm.distributed.device_communicators.tpu_communicator import (
TpuCommunicator)
self.tpu_communicator: Optional[TpuCommunicator]
self.tpu_communicator: Optional[TpuCommunicator] = None
if use_tpu_communicator and self.world_size > 1:
self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
@@ -264,16 +317,46 @@ class GroupCoordinator:
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""
User-facing all-reduce function before we actually call the
all-reduce operation.
We need this because Dynamo does not support passing an arbitrary
object (`self` in this case) to a custom op. Instead, we pass the group
name as a string, look up the group coordinator by that name, and
dispatch the all-reduce operation to it.
In addition, PyTorch custom ops do not support both mutating an input
and returning a new tensor in the same op, so we have to decide ahead
of time whether the op is in-place or out-of-place.
"""
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_
if self.tpu_communicator is not None and \
not self.tpu_communicator.disabled:
# TPU handles Dynamo with its own logic.
return self._all_reduce(input_)
if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
return torch.ops.vllm.outplace_all_reduce(
input_, group_name=self.unique_name)
else:
torch.ops.vllm.inplace_all_reduce(input_,
group_name=self.unique_name)
return input_
def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""
The actual all-reduce implementation.
NOTE: This operation will be applied in-place or out-of-place.
Always assume this function modifies its input, but use the return
value as the output.
"""
ca_comm = self.ca_comm
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_
# For TPUs, use TPU communicator.
tpu_comm = self.tpu_communicator
if tpu_comm is not None and not tpu_comm.disabled:
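
As the new docstring explains, a PyTorch custom op cannot both mutate an input and return a fresh tensor, so the dispatcher above calls vllm::outplace_all_reduce only when the custom all-reduce path (which produces a new tensor) will be taken, and vllm::inplace_all_reduce otherwise. A toy sketch of the two op shapes; the demo namespace and op names are hypothetical:

import torch

# In-place shape: declares the mutated argument and returns nothing.
@torch.library.custom_op("demo::inplace_double", mutates_args=["t"])
def inplace_double(t: torch.Tensor) -> None:
    t.mul_(2)

@inplace_double.register_fake
def _(t: torch.Tensor) -> None:
    return

# Out-of-place shape: mutates nothing and returns a new tensor.
@torch.library.custom_op("demo::outplace_double", mutates_args=[])
def outplace_double(t: torch.Tensor) -> torch.Tensor:
    return t * 2

@outplace_double.register_fake
def _(t: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(t)

x = torch.ones(4)
torch.ops.demo.inplace_double(x)        # mutates x, returns nothing
y = torch.ops.demo.outplace_double(x)   # leaves x alone, returns a new tensor
print(x, y)                             # tensor([2., 2., 2., 2.]) tensor([4., 4., 4., 4.])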
@@ -758,6 +841,7 @@ def init_world_group(ranks: List[int], local_rank: int,
use_pynccl=False,
use_custom_allreduce=False,
use_tpu_communicator=False,
group_name="world",
)
@@ -767,6 +851,7 @@ def init_model_parallel_group(
backend: str,
use_custom_allreduce: Optional[bool] = None,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
@@ -778,6 +863,7 @@ def init_model_parallel_group(
use_custom_allreduce=use_custom_allreduce,
use_tpu_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
)
@@ -931,7 +1017,8 @@ def initialize_model_parallel(
_TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True)
use_message_queue_broadcaster=True,
group_name="tp")
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = (world_size //
@@ -947,7 +1034,8 @@ def initialize_model_parallel(
_PP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
use_custom_allreduce=False)
use_custom_allreduce=False,
group_name="pp")
def ensure_model_parallel_initialized(