diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 74b287c7..00fed96c 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -463,6 +463,7 @@ steps:
   - vllm/worker/worker.py
   - vllm/worker/model_runner.py
   commands:
+  - torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
   - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
diff --git a/examples/offline_inference/torchrun_example.py b/examples/offline_inference/torchrun_example.py
new file mode 100644
index 00000000..b6de73eb
--- /dev/null
+++ b/examples/offline_inference/torchrun_example.py
@@ -0,0 +1,64 @@
+"""
+Experimental support for tensor-parallel inference with torchrun;
+see https://github.com/vllm-project/vllm/issues/11400 for
+the motivation and use case for this example.
+Run the script with `torchrun --nproc-per-node=2 torchrun_example.py`;
+the argument 2 should match the `tensor_parallel_size` below.
+See `tests/distributed/test_torchrun_example.py` for the unit test.
+"""
+
+from vllm import LLM, SamplingParams
+
+# Create prompts, the same across all ranks
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+# Create sampling parameters, the same across all ranks
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+# Use `distributed_executor_backend="external_launcher"` so that
+# this LLM engine/instance only creates one worker.
+llm = LLM(
+    model="facebook/opt-125m",
+    tensor_parallel_size=2,
+    distributed_executor_backend="external_launcher",
+)
+
+outputs = llm.generate(prompts, sampling_params)
+
+# All ranks will have the same outputs
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, "
+          f"Generated text: {generated_text!r}")
+"""
+Further tips:
+
+1. To communicate control messages across all ranks, use the CPU group,
+a PyTorch ProcessGroup with the GLOO backend.
+
+```python
+from vllm.distributed.parallel_state import get_world_group
+cpu_group = get_world_group().cpu_group
+torch_rank = dist.get_rank(group=cpu_group)
+if torch_rank == 0:
+    # do something for rank 0, e.g. saving the results to disk.
+```
+
+2. To communicate data across all ranks, use the model's device group,
+a PyTorch ProcessGroup with the NCCL backend.
+```python
+from vllm.distributed.parallel_state import get_world_group
+device_group = get_world_group().device_group
+```
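+
+For example, rank 0 could broadcast a tensor to the other ranks over this
+device group. The snippet below is an illustrative sketch (not part of the
+original tips); it assumes `torch` is imported and `torch.distributed` is
+imported as `dist`, as in the snippet for tip 1.
+```python
+device_group = get_world_group().device_group
+# every rank must allocate a tensor of the same shape and dtype
+data = torch.zeros(4, device="cuda")
+if dist.get_rank(group=device_group) == 0:
+    data += 1.0  # rank 0 fills in the payload
+dist.broadcast(data, src=0, group=device_group)
+# now every rank holds the tensor produced on rank 0
+```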
+
+3. To access the model directly in every rank, use the following code:
+```python
+llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
+```
+"""
diff --git a/tests/distributed/test_torchrun_example.py b/tests/distributed/test_torchrun_example.py
new file mode 100644
index 00000000..7aa03d7f
--- /dev/null
+++ b/tests/distributed/test_torchrun_example.py
@@ -0,0 +1,56 @@
+# unit test for `examples/offline_inference/torchrun_example.py`
+
+import random
+
+import torch.distributed as dist
+
+from vllm import LLM, SamplingParams
+from vllm.distributed.parallel_state import get_world_group
+
+# Create prompts
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+# Set different `gpu_memory_utilization` and `swap_space` on different ranks,
+# to test whether all ranks agree on the same KV cache configuration.
+llm = LLM(model="facebook/opt-125m",
+          tensor_parallel_size=2,
+          distributed_executor_backend="external_launcher",
+          gpu_memory_utilization=random.uniform(0.7, 0.9),
+          swap_space=random.randint(1, 4))
+
+outputs = llm.generate(prompts, sampling_params)
+
+cpu_group = get_world_group().cpu_group
+
+torch_rank = dist.get_rank(group=cpu_group)
+
+
+def test_consistent_across_ranks(obj):
+    if torch_rank == 0:
+        dist.broadcast_object_list([obj], src=0, group=cpu_group)
+    else:
+        container = [None]
+        dist.broadcast_object_list(container, src=0, group=cpu_group)
+        assert container[0] == obj
+
+
+test_consistent_across_ranks(
+    llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
+test_consistent_across_ranks(
+    llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
+
+# All ranks should have the same outputs
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    test_consistent_across_ranks(prompt)
+    test_consistent_across_ranks(generated_text)
+    print(f"Rank {torch_rank}, Prompt: {prompt!r}, "
+          f"Generated text: {generated_text!r}")
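+
+# Optional extra check (an illustrative addition, not strictly required): the
+# number of outputs should also agree across ranks, since every rank
+# processed the same prompts.
+test_consistent_across_ranks(len(outputs))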
diff --git a/tests/engine/test_multiproc_workers.py b/tests/engine/test_multiproc_workers.py
index db70a808..04505fca 100644
--- a/tests/engine/test_multiproc_workers.py
+++ b/tests/engine/test_multiproc_workers.py
@@ -22,7 +22,7 @@ class DummyWorkerWrapper(WorkerWrapperBase):
             # simulate error case
             raise worker_input
 
-        return self.rank, input
+        return self.rpc_rank, input
 
 
 def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
diff --git a/vllm/config.py b/vllm/config.py
index 2fe674b8..a5f21610 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1338,14 +1338,15 @@ class ParallelConfig:
         from vllm.executor.executor_base import ExecutorBase
         from vllm.platforms import current_platform
         if self.distributed_executor_backend not in (
-                "ray", "mp", "uni", None) and not (isinstance(
+                "ray", "mp", "uni",
+                "external_launcher", None) and not (isinstance(
                 self.distributed_executor_backend, type) and issubclass(
                     self.distributed_executor_backend, ExecutorBase)):
             raise ValueError(
                 "Unrecognized distributed executor backend "
                 f"{self.distributed_executor_backend}. Supported "
-                "values are 'ray', 'mp' 'uni', or custom ExecutorBase"
-                " subclass.")
+                "values are 'ray', 'mp', 'uni', 'external_launcher' or"
+                " a custom ExecutorBase subclass.")
         if self.use_ray:
             from vllm.executor import ray_utils
             ray_utils.assert_ray_available()
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 03a8959a..a4f4c955 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -388,7 +388,7 @@ class EngineArgs:
         # Parallel arguments
         parser.add_argument(
             '--distributed-executor-backend',
-            choices=['ray', 'mp'],
+            choices=['ray', 'mp', 'uni', 'external_launcher'],
             default=EngineArgs.distributed_executor_backend,
             help='Backend to use for distributed model '
             'workers, either "ray" or "mp" (multiprocessing). If the product '
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 49a1e9f5..5d19ce03 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -457,6 +457,11 @@ class LLMEngine:
            # JAX-style, single-process, multi-device executor.
            from vllm.executor.uniproc_executor import UniProcExecutor
            executor_class = UniProcExecutor
+        elif distributed_executor_backend == "external_launcher":
+            # executor with external launcher
+            from vllm.executor.uniproc_executor import (  # noqa
+                ExecutorWithExternalLauncher)
+            executor_class = ExecutorWithExternalLauncher
         else:
            from vllm.executor.uniproc_executor import UniProcExecutor
            executor_class = UniProcExecutor
diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py
index edceece4..3baeb639 100644
--- a/vllm/executor/ray_distributed_executor.py
+++ b/vllm/executor/ray_distributed_executor.py
@@ -172,7 +172,7 @@
                     scheduling_strategy=scheduling_strategy,
                     **ray_remote_kwargs,
                 )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
-                                           rank=rank)
+                                           rpc_rank=rank)
             else:
                 worker = ray.remote(
                     num_cpus=0,
                     num_gpus=num_gpus,
                     scheduling_strategy=scheduling_strategy,
                     **ray_remote_kwargs,
                 )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
-                                           rank=rank)
+                                           rpc_rank=rank)
             worker_metadata.append(
                 RayWorkerMetaData(worker=worker, created_rank=rank))
             rank += 1
@@ -204,7 +204,7 @@
                 # as the resource holder for the driver process.
                 self.driver_dummy_worker = worker
                 self.driver_worker = RayWorkerWrapper(
-                    vllm_config=self.vllm_config, rank=0)
+                    vllm_config=self.vllm_config, rpc_rank=0)
                 worker_metadata.pop(i)
                 break
diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py
index da1d7734..27b83e95 100644
--- a/vllm/executor/uniproc_executor.py
+++ b/vllm/executor/uniproc_executor.py
@@ -1,5 +1,10 @@
+import os
 from typing import Any, Dict, List, Optional, Tuple
 
+import torch
+import torch.distributed as dist
+
+import vllm.envs as envs
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
@@ -16,7 +21,7 @@ class UniProcExecutor(ExecutorBase):
         """Initialize the worker and load the model.
""" self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, - rank=0) + rpc_rank=0) distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) local_rank = 0 @@ -55,3 +60,77 @@ class UniProcExecutor(ExecutorBase): UniProcExecutorAsync = UniProcExecutor + + +class ExecutorWithExternalLauncher(UniProcExecutor): + """An executor that uses external launchers to launch engines, + specially designed for torchrun-compatible launchers, for + offline inference with tensor parallelism. + + see https://github.com/vllm-project/vllm/issues/11400 for + the motivation, and examples/offline_inference/torchrun_example.py + for the usage example. + + The key idea: although it is tensor-parallel inference, we only + create one worker per executor, users will launch multiple + engines with torchrun-compatible launchers, and all these engines + work together to process the same prompts. When scheduling is + deterministic, all the engines will generate the same outputs, + and they don't need to synchronize the states with each other. + """ + uses_ray: bool = False + + def _init_executor(self) -> None: + """Initialize the worker and load the model. + """ + assert self.vllm_config.parallel_config.pipeline_parallel_size == 1, \ + ("ExecutorWithExternalLauncher does not " + "support pipeline parallelism.") + assert self.vllm_config.scheduler_config.delay_factor == 0.0, \ + ("ExecutorWithExternalLauncher needs deterministic " + "execution, so it" + "does not support delay_factor in scheduling") + assert not envs.VLLM_USE_V1, \ + ("V1 architecture cannot guarantee deterministic execution, " + "so it is not supported in ExecutorWithExternalLauncher.") + self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, + rpc_rank=0) + # engines are launched in torchrun-compatible launchers + # so we can use the env:// method. + # required env vars: + # - RANK + # - MASTER_ADDR + # - MASTER_PORT + distributed_init_method = "env://" + rank = int(os.environ["RANK"]) + local_rank = rank + is_driver_worker = True + kwargs = dict( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=is_driver_worker, + ) + self.collective_rpc("init_worker", args=([kwargs], )) + self.collective_rpc("init_device") + self.collective_rpc("load_model") + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """ + Determine the number of available KV blocks. + Add an additional all_reduce to get the min across all ranks. + Note that even if we have the same `gpu_memory_utilization` and + `swap_space`, the available memory in every rank might still + differ because NCCL can take different amounts of memory in + different ranks. Therefore, it is necessary to test if all ranks + agree on the same KV cache configuration. 
+        """
+        a, b = super().determine_num_available_blocks()
+        from vllm.distributed.parallel_state import get_world_group
+        cpu_group = get_world_group().cpu_group
+        a_tensor = torch.tensor([a], device="cpu", dtype=torch.int64)
+        b_tensor = torch.tensor([b], device="cpu", dtype=torch.int64)
+        dist.all_reduce(a_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
+        dist.all_reduce(b_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
+        return a_tensor.item(), b_tensor.item()
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index dd981ffc..e6f26d2b 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -940,8 +940,8 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
         return self.base_layer.soft_cap
 
     @property
-    def use_gather(self):
-        return self.base_layer.use_gather
+    def use_all_gather(self):
+        return self.base_layer.use_all_gather
 
     @property
     def org_vocab_size(self):
diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index 2bc7e458..42decde1 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -6,6 +6,7 @@
 import torch
 import torch.nn as nn
 import vllm.envs as envs
+from vllm.config import get_current_vllm_config
 from vllm.distributed import (tensor_model_parallel_all_gather,
                               tensor_model_parallel_gather)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -44,8 +45,10 @@ class LogitsProcessor(nn.Module):
         self.soft_cap = soft_cap
 
         # Whether to use gather or all-gather to gather the logits.
-        self.use_gather = not current_platform.is_tpu(
-        ) and not envs.VLLM_USE_V1
+        parallel_config = get_current_vllm_config().parallel_config
+        self.use_all_gather = current_platform.is_tpu() \
+            or envs.VLLM_USE_V1 \
+            or parallel_config.distributed_executor_backend == "external_launcher"  # noqa
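+        # (Clarifying note: with the "external_launcher" backend, every rank
+        # runs its own engine and samples locally, so every rank needs the
+        # full logits; a plain gather would leave ranks other than the driver
+        # with None.)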
 
     def forward(
         self,
@@ -88,16 +91,17 @@
         logits = lm_head.linear_method.apply(lm_head,
                                              hidden_states,
                                              bias=embedding_bias)
-        if self.use_gather:
-            # None may be returned for rank > 0
-            logits = tensor_model_parallel_gather(logits)
-        else:
+
+        if self.use_all_gather:
             # Gather is not supported for some devices such as TPUs.
             # Use all-gather instead.
             # NOTE(woosuk): Here, the outputs of every device should not be None
             # because XLA requires strict SPMD among all devices. Every device
             # should execute the same operations after gathering the logits.
             logits = tensor_model_parallel_all_gather(logits)
+        else:
+            # None may be returned for rank > 0
+            logits = tensor_model_parallel_gather(logits)
         # Remove paddings in vocab (if any).
         if logits is not None:
             logits = logits[..., :self.org_vocab_size]
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index cee0fcc0..e111ac7e 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -246,7 +246,7 @@ class WorkerProc:
         ready_path: str,
     ):
         self.rank = rank
-        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rank=rank)
+        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
         # TODO: move `init_worker` to executor level as a collective rpc call
         all_kwargs: List[Dict] = [
             {} for _ in range(vllm_config.parallel_config.world_size)
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index a3e377ef..43eeb287 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -55,9 +55,6 @@ class Worker(LocalOrDistributedWorkerBase):
         self.rank = rank
         self.distributed_init_method = distributed_init_method
         self.is_driver_worker = is_driver_worker
-        if is_driver_worker:
-            assert rank % self.parallel_config.tensor_parallel_size == 0, \
-                "Driver worker should be rank 0 of tensor parallel group."
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
             from vllm.utils import init_cached_hf_modules
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 7c14b834..d464b614 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -461,7 +461,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
 
 class WorkerWrapperBase:
     """
-    The whole point of this class is to lazily initialize the worker.
+    This class represents one process in an executor/engine. It is responsible
+    for lazily initializing the worker and handling the worker's lifecycle.
     We first instantiate the WorkerWrapper, which remembers the worker
     module and class name. Then, when we call `update_environment_variables`,
     and the real initialization happens in `init_worker`.
@@ -470,9 +471,19 @@
     def __init__(
         self,
         vllm_config: VllmConfig,
-        rank: int = 0,
+        rpc_rank: int = 0,
     ) -> None:
-        self.rank = rank
+        """
+        Initialize the worker wrapper with the given vllm_config and rpc_rank.
+        Note: rpc_rank is the rank of the worker in the executor. In most
+        cases, it is also the rank of the worker in the distributed group.
+        However, when multiple executors work together, they can be different,
+        e.g. in the case of SPMD-style offline inference with TP=2, users
+        can launch 2 engines/executors, each with only 1 worker.
+        All workers have rpc_rank=0, but they have different ranks in the TP
+        group.
+        """
+        self.rpc_rank = rpc_rank
         self.vllm_config = vllm_config
         self.worker: Optional[WorkerBase] = None
         if vllm_config.model_config is not None:
@@ -485,16 +496,16 @@
     def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
         """
-        Adjust the rank based on the given mapping.
+        Adjust the rpc_rank based on the given mapping.
         It is only used during the initialization of the executor,
-        to adjust the rank of workers after we create all workers.
+        to adjust the rpc_rank of workers after we create all workers.
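+        For example, a rank_mapping of {0: 3} changes this wrapper's rpc_rank
+        from 0 to 3 (illustrative values).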
""" - if self.rank in rank_mapping: - self.rank = rank_mapping[self.rank] + if self.rpc_rank in rank_mapping: + self.rpc_rank = rank_mapping[self.rpc_rank] def update_environment_variables(self, envs_list: List[Dict[str, str]]) -> None: - envs = envs_list[self.rank] + envs = envs_list[self.rpc_rank] key = 'CUDA_VISIBLE_DEVICES' if key in envs and key in os.environ: # overwriting CUDA_VISIBLE_DEVICES is desired behavior @@ -507,7 +518,7 @@ class WorkerWrapperBase: Here we inject some common logic before initializing the worker. Arguments are passed to the worker class constructor. """ - kwargs = all_kwargs[self.rank] + kwargs = all_kwargs[self.rpc_rank] enable_trace_function_call_for_thread(self.vllm_config) # see https://github.com/NVIDIA/nccl/issues/1234