[Core][Distributed] enable multiple tp group (#4512)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
commit 2a85f93007
parent cf8cac8c70
@@ -25,19 +25,24 @@ steps:
 - label: Distributed Comm Ops Test
   command: pytest -v -s test_comm_ops.py
   working_dir: "/vllm-workspace/tests/distributed"
-  num_gpus: 2 # only support 1 or 2 for now.
+  num_gpus: 2
 
 - label: Distributed Tests
   working_dir: "/vllm-workspace/tests/distributed"
-  num_gpus: 2 # only support 1 or 2 for now.
+  num_gpus: 2
   commands:
-  - pytest -v -s test_pynccl.py
   - pytest -v -s test_pynccl_library.py
   - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
 
+- label: Distributed Tests (Multiple Groups)
+  working_dir: "/vllm-workspace/tests/distributed"
+  num_gpus: 4
+  commands:
+  - pytest -v -s test_pynccl.py
+
 - label: Engine Test
   command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
 
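The new 4-GPU step drives `test_pynccl.py` through the suite's `distributed_run` helper, which spawns one worker process per GPU. The helper's body is not part of this diff; the following is only a minimal sketch of what such a launcher typically looks like (the names, the gloo backend choice, and the port are illustrative assumptions, not vLLM's actual implementation):

# Hypothetical sketch of a distributed_run(fn, world_size)-style launcher;
# the real helper in tests/distributed/test_pynccl.py is not shown here.
import os

import torch.distributed as dist
import torch.multiprocessing as mp


def _worker_entry(rank: int, world_size: int, fn):
    # Every process joins a CPU-side gloo world; NCCL communicators are
    # created later on top of (sub)groups of this world.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"  # assumed-free port
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
    fn()
    dist.destroy_process_group()


def distributed_run(fn, world_size: int):
    # mp.spawn passes the process index as the first argument to _worker_entry.
    mp.spawn(_worker_entry, args=(world_size, fn), nprocs=world_size)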
@@ -45,6 +45,9 @@ steps:
     plugins:
     - kubernetes:
         podSpec:
+          {% if step.num_gpus %}
+          priorityClassName: gpu-priority-cls-{{ step.num_gpus }}
+          {% endif %}
           volumes:
           - name: dshm
             emptyDir:
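The template change tags each test pod with a Kubernetes priority class keyed to its GPU count, presumably so the cluster can schedule 4-GPU pods sensibly next to 2-GPU ones. For illustration, the added conditional renders like this when fed a step object (a standalone jinja2 sketch; `Step` stands in for the real template context):

# Standalone sketch: render the added Jinja snippet outside the real template.
import jinja2

snippet = jinja2.Template(
    "{% if step.num_gpus %}"
    "priorityClassName: gpu-priority-cls-{{ step.num_gpus }}"
    "{% endif %}")


class Step:  # stand-in for the pipeline's per-step context object
    num_gpus = 4


print(snippet.render(step=Step()))
# -> priorityClassName: gpu-priority-cls-4
# Steps without num_gpus render to an empty string, so the field is omitted.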
@@ -58,6 +58,34 @@ def test_pynccl():
     distributed_run(worker_fn, 2)
 
 
+@worker_fn_wrapper
+def multiple_tp_worker_fn():
+    device = torch.device(f"cuda:{torch.distributed.get_rank()}")
+    groups = [
+        torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
+        torch.distributed.new_group(ranks=[2, 3], backend="gloo")
+    ]
+    group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
+    comm = NCCLCommunicator(group=group, device=device)
+    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
+    # the two groups can communicate independently
+    if torch.distributed.get_rank() in [0, 1]:
+        comm.all_reduce(tensor)
+        comm.all_reduce(tensor)
+        result = tensor.mean().cpu().item()
+        assert result == 4
+    else:
+        comm.all_reduce(tensor)
+        result = tensor.mean().cpu().item()
+        assert result == 2
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
+def test_pynccl_multiple_tp():
+    distributed_run(multiple_tp_worker_fn, 4)
+
+
 @worker_fn_wrapper
 def worker_fn_with_cudagraph():
     with torch.no_grad():
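Two details are easy to miss here. First, `torch.distributed.new_group` must be called by every process in the world, even those that will not belong to the resulting group, which is why each worker constructs both groups and then selects its own. Second, the assertions follow from the all-reduce arithmetic: ranks 0-1 reduce a tensor of ones twice (1 -> 2 -> 4) while ranks 2-3 reduce once (1 -> 2), so matching results on both sides would indicate cross-group leakage. The hard-coded [0, 1] / [2, 3] split generalizes in the obvious way; a small illustrative helper (not part of the diff):

# Illustrative only: partition `world_size` ranks into contiguous
# tensor-parallel groups of `tp_size`, matching the test's 4-rank / tp=2 case.
def tp_group_ranks(world_size: int, tp_size: int) -> list[list[int]]:
    assert world_size % tp_size == 0
    return [list(range(i * tp_size, (i + 1) * tp_size))
            for i in range(world_size // tp_size)]


assert tp_group_ranks(4, 2) == [[0, 1], [2, 3]]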
@@ -232,6 +232,7 @@ class NCCLCommunicator:
         assert dist.get_backend(group) != dist.Backend.NCCL, (
             "NCCLCommunicator should be attached to a non-NCCL group.")
         self.group = group
+        # note: this rank is the rank in the group
         self.rank = dist.get_rank(group)
         self.world_size = dist.get_world_size(group)
         if self.rank == 0:
@@ -239,7 +240,9 @@ class NCCLCommunicator:
         else:
             self.unique_id = NcclUniqueId()
         tensor = torch.ByteTensor(list(self.unique_id.internal))
-        dist.broadcast(tensor, src=0, group=group)
+        ranks = dist.get_process_group_ranks(group)
+        # arg `src` in `broadcast` is the global rank
+        dist.broadcast(tensor, src=ranks[0], group=group)
         byte_list = tensor.tolist()
         for i, byte in enumerate(byte_list):
             self.unique_id.internal[i] = byte
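This is the core correctness fix for multiple groups: `dist.broadcast` interprets `src` as a rank in the default (global) world, while `dist.get_rank(group)` returns the rank inside the subgroup. The NCCL unique id must therefore be broadcast from `ranks[0]`, the global rank of the group's first member; for the [2, 3] group that is 2, not 0, and `src=0` would address a process outside the group entirely. A runnable CPU sketch of the distinction, assuming 4 gloo processes split as in the test (the port is an assumed-free one):

# Runnable sketch (gloo, 4 CPU processes): `src` in dist.broadcast is a
# global rank, not a group-local one.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"  # assumed-free port
    dist.init_process_group("gloo", rank=rank, world_size=4)
    # Every process must create every group, as in the test above.
    groups = [dist.new_group([0, 1]), dist.new_group([2, 3])]
    group = groups[0] if rank in (0, 1) else groups[1]
    ranks = dist.get_process_group_ranks(group)  # [0, 1] or [2, 3]
    t = torch.tensor([float(rank)])
    # `src` is a *global* rank: 0 for the first group, 2 for the second.
    dist.broadcast(t, src=ranks[0], group=group)
    assert t.item() == ranks[0]
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, nprocs=4)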