Support torchrun and SPMD-style offline inference (#12071)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent dd7c9ad870 · commit bf53e0c70b
@@ -463,6 +463,7 @@ steps:
  - vllm/worker/worker.py
  - vllm/worker/model_runner.py
  commands:
+  - torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
  - pytest -v -s ./compile/test_basic_correctness.py
  - pytest -v -s ./compile/test_wrapper.py
  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
examples/offline_inference/torchrun_example.py (new file, 64 lines)
@@ -0,0 +1,64 @@
"""
Experimental support for tensor-parallel inference with torchrun;
see https://github.com/vllm-project/vllm/issues/11400 for
the motivation and use case for this example.
Run the script with `torchrun --nproc-per-node=2 torchrun_example.py`;
the argument 2 should match the `tensor_parallel_size` below.
See `tests/distributed/test_torchrun_example.py` for the unit test.
"""

from vllm import LLM, SamplingParams

# Create prompts, the same across all ranks
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create sampling parameters, the same across all ranks
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Use `distributed_executor_backend="external_launcher"` so that
# this llm engine/instance only creates one worker.
llm = LLM(
    model="facebook/opt-125m",
    tensor_parallel_size=2,
    distributed_executor_backend="external_launcher",
)

outputs = llm.generate(prompts, sampling_params)

# All ranks will have the same outputs
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")
"""
Further tips:

1. To communicate control messages across all ranks, use the cpu group,
a PyTorch ProcessGroup with GLOO backend.

```python
from vllm.distributed.parallel_state import get_world_group
cpu_group = get_world_group().cpu_group
torch_rank = dist.get_rank(group=cpu_group)
if torch_rank == 0:
    # do something for rank 0, e.g. saving the results to disk.
```

2. To communicate data across all ranks, use the model's device group,
a PyTorch ProcessGroup with NCCL backend.
```python
from vllm.distributed.parallel_state import get_world_group
device_group = get_world_group().device_group
```

3. To access the model directly in every rank, use the following code:
```python
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
```
"""
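As a companion to tip 2 above, here is a minimal sketch (not part of this commit) of actually moving tensor data over the NCCL-backed device group. It assumes it runs inside the torchrun example after `llm` has been constructed (so vLLM's process groups exist and the CUDA device of each rank is already set); the payload tensor and its values are made up for illustration:

```python
import torch
import torch.distributed as dist

from vllm.distributed.parallel_state import get_world_group

# NCCL-backed group covering all ranks of this torchrun launch; it only
# exists after the LLM(...) call above has initialized vLLM's process groups.
device_group = get_world_group().device_group

# Hypothetical payload: rank 0 fills a CUDA tensor and broadcasts it,
# so every rank ends up holding the same values.
payload = torch.zeros(4, device="cuda")
if dist.get_rank(group=device_group) == 0:
    payload.fill_(42.0)
dist.broadcast(payload, src=0, group=device_group)
```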
tests/distributed/test_torchrun_example.py (new file, 56 lines)
@@ -0,0 +1,56 @@
# unit test for `examples/offline_inference/torchrun_example.py`

import random

import torch.distributed as dist

from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import get_world_group

# Create prompts
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Set different `gpu_memory_utilization` and `swap_space` for different ranks,
# to test if all ranks agree on the same KV cache configuration.
llm = LLM(model="facebook/opt-125m",
          tensor_parallel_size=2,
          distributed_executor_backend="external_launcher",
          gpu_memory_utilization=random.uniform(0.7, 0.9),
          swap_space=random.randint(1, 4))

outputs = llm.generate(prompts, sampling_params)

cpu_group = get_world_group().cpu_group

torch_rank = dist.get_rank(group=cpu_group)


def test_consistent_across_ranks(obj):
    if torch_rank == 0:
        dist.broadcast_object_list([obj], src=0, group=cpu_group)
    else:
        container = [None]
        dist.broadcast_object_list(container, src=0, group=cpu_group)
        assert container[0] == obj


test_consistent_across_ranks(
    llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
test_consistent_across_ranks(
    llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)

# All ranks should have the same outputs
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    test_consistent_across_ranks(prompt)
    test_consistent_across_ranks(generated_text)
    print(f"Rank {torch_rank}, Prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")
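For reference, the CI entry added to the test pipeline above runs this file directly under torchrun rather than pytest, `torchrun --nproc-per-node=2 distributed/test_torchrun_example.py`, so the two launched processes match the `tensor_parallel_size=2` used here.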
@@ -22,7 +22,7 @@ class DummyWorkerWrapper(WorkerWrapperBase):
            # simulate error case
            raise worker_input

-        return self.rank, input
+        return self.rpc_rank, input


def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
@@ -1338,14 +1338,15 @@ class ParallelConfig:
        from vllm.executor.executor_base import ExecutorBase
        from vllm.platforms import current_platform
        if self.distributed_executor_backend not in (
-                "ray", "mp", "uni", None) and not (isinstance(
+                "ray", "mp", "uni",
+                "external_launcher", None) and not (isinstance(
                    self.distributed_executor_backend, type) and issubclass(
                        self.distributed_executor_backend, ExecutorBase)):
            raise ValueError(
                "Unrecognized distributed executor backend "
                f"{self.distributed_executor_backend}. Supported "
-                "values are 'ray', 'mp' 'uni', or custom ExecutorBase"
-                " subclass.")
+                "values are 'ray', 'mp' 'uni', 'external_launcher' or"
+                " custom ExecutorBase subclass.")
        if self.use_ray:
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()
@@ -388,7 +388,7 @@ class EngineArgs:
        # Parallel arguments
        parser.add_argument(
            '--distributed-executor-backend',
-            choices=['ray', 'mp'],
+            choices=['ray', 'mp', 'uni', 'external_launcher'],
            default=EngineArgs.distributed_executor_backend,
            help='Backend to use for distributed model '
            'workers, either "ray" or "mp" (multiprocessing). If the product '
@@ -457,6 +457,11 @@ class LLMEngine:
            # JAX-style, single-process, multi-device executor.
            from vllm.executor.uniproc_executor import UniProcExecutor
            executor_class = UniProcExecutor
+        elif distributed_executor_backend == "external_launcher":
+            # executor with external launcher
+            from vllm.executor.uniproc_executor import (  # noqa
+                ExecutorWithExternalLauncher)
+            executor_class = ExecutorWithExternalLauncher
        else:
            from vllm.executor.uniproc_executor import UniProcExecutor
            executor_class = UniProcExecutor
@@ -172,7 +172,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
-                                           rank=rank)
+                                           rpc_rank=rank)
            else:
                worker = ray.remote(
                    num_cpus=0,
@@ -181,7 +181,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
-                                           rank=rank)
+                                           rpc_rank=rank)
            worker_metadata.append(
                RayWorkerMetaData(worker=worker, created_rank=rank))
            rank += 1
@@ -204,7 +204,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
-                    vllm_config=self.vllm_config, rank=0)
+                    vllm_config=self.vllm_config, rpc_rank=0)
                worker_metadata.pop(i)
                break

@@ -1,5 +1,10 @@
+import os
from typing import Any, Dict, List, Optional, Tuple

+import torch
+import torch.distributed as dist
+
+import vllm.envs as envs
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
@@ -16,7 +21,7 @@ class UniProcExecutor(ExecutorBase):
        """Initialize the worker and load the model.
        """
        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
-                                               rank=0)
+                                               rpc_rank=0)
        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
        local_rank = 0
@@ -55,3 +60,77 @@


UniProcExecutorAsync = UniProcExecutor
+
+
+class ExecutorWithExternalLauncher(UniProcExecutor):
+    """An executor that uses external launchers to launch engines,
+    specially designed for torchrun-compatible launchers, for
+    offline inference with tensor parallelism.
+
+    see https://github.com/vllm-project/vllm/issues/11400 for
+    the motivation, and examples/offline_inference/torchrun_example.py
+    for the usage example.
+
+    The key idea: although it is tensor-parallel inference, we only
+    create one worker per executor, users will launch multiple
+    engines with torchrun-compatible launchers, and all these engines
+    work together to process the same prompts. When scheduling is
+    deterministic, all the engines will generate the same outputs,
+    and they don't need to synchronize the states with each other.
+    """
+    uses_ray: bool = False
+
+    def _init_executor(self) -> None:
+        """Initialize the worker and load the model.
+        """
+        assert self.vllm_config.parallel_config.pipeline_parallel_size == 1, \
+            ("ExecutorWithExternalLauncher does not "
+             "support pipeline parallelism.")
+        assert self.vllm_config.scheduler_config.delay_factor == 0.0, \
+            ("ExecutorWithExternalLauncher needs deterministic "
+             "execution, so it "
+             "does not support delay_factor in scheduling")
+        assert not envs.VLLM_USE_V1, \
+            ("V1 architecture cannot guarantee deterministic execution, "
+             "so it is not supported in ExecutorWithExternalLauncher.")
+        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
+                                               rpc_rank=0)
+        # engines are launched in torchrun-compatible launchers
+        # so we can use the env:// method.
+        # required env vars:
+        # - RANK
+        # - MASTER_ADDR
+        # - MASTER_PORT
+        distributed_init_method = "env://"
+        rank = int(os.environ["RANK"])
+        local_rank = rank
+        is_driver_worker = True
+        kwargs = dict(
+            vllm_config=self.vllm_config,
+            local_rank=local_rank,
+            rank=rank,
+            distributed_init_method=distributed_init_method,
+            is_driver_worker=is_driver_worker,
+        )
+        self.collective_rpc("init_worker", args=([kwargs], ))
+        self.collective_rpc("init_device")
+        self.collective_rpc("load_model")
+
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """
+        Determine the number of available KV blocks.
+        Add an additional all_reduce to get the min across all ranks.
+        Note that even if we have the same `gpu_memory_utilization` and
+        `swap_space`, the available memory in every rank might still
+        differ because NCCL can take different amounts of memory in
+        different ranks. Therefore, it is necessary to test if all ranks
+        agree on the same KV cache configuration.
+        """
+        a, b = super().determine_num_available_blocks()
+        from vllm.distributed.parallel_state import get_world_group
+        cpu_group = get_world_group().cpu_group
+        a_tensor = torch.tensor([a], device="cpu", dtype=torch.int64)
+        b_tensor = torch.tensor([b], device="cpu", dtype=torch.int64)
+        dist.all_reduce(a_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
+        dist.all_reduce(b_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
+        return a_tensor.item(), b_tensor.item()
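The `env://` init method used by `ExecutorWithExternalLauncher` leans on the environment variables that torchrun (and other torchrun-compatible launchers) export to every process. Below is a minimal standalone sketch (not part of this commit) of what each rank sees; run it with `torchrun --nproc-per-node=2 env_demo.py`, where the file name is arbitrary:

```python
import os

import torch.distributed as dist

# torchrun exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for every
# process it spawns; with init_method="env://", PyTorch reads them instead
# of an explicitly passed address.
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
rendezvous = f'{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}'
print(f"rank {rank}/{world_size}, rendezvous at {rendezvous}")

# Gloo keeps the sketch runnable on CPU-only machines.
dist.init_process_group(backend="gloo", init_method="env://")
dist.barrier()
dist.destroy_process_group()
```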
@@ -940,8 +940,8 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
        return self.base_layer.soft_cap

    @property
-    def use_gather(self):
-        return self.base_layer.use_gather
+    def use_all_gather(self):
+        return self.base_layer.use_all_gather

    @property
    def org_vocab_size(self):
@@ -6,6 +6,7 @@ import torch
import torch.nn as nn

import vllm.envs as envs
+from vllm.config import get_current_vllm_config
from vllm.distributed import (tensor_model_parallel_all_gather,
                              tensor_model_parallel_gather)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -44,8 +45,10 @@ class LogitsProcessor(nn.Module):
        self.soft_cap = soft_cap
        # Whether to use gather or all-gather to gather the logits.
-        self.use_gather = not current_platform.is_tpu(
-        ) and not envs.VLLM_USE_V1
+        parallel_config = get_current_vllm_config().parallel_config
+        self.use_all_gather = current_platform.is_tpu() \
+            or envs.VLLM_USE_V1 \
+            or parallel_config.distributed_executor_backend == "external_launcher"  # noqa

    def forward(
        self,
@@ -88,16 +91,17 @@ class LogitsProcessor(nn.Module):
        logits = lm_head.linear_method.apply(lm_head,
                                             hidden_states,
                                             bias=embedding_bias)
-        if self.use_gather:
-            # None may be returned for rank > 0
-            logits = tensor_model_parallel_gather(logits)
-        else:
+
+        if self.use_all_gather:
            # Gather is not supported for some devices such as TPUs.
            # Use all-gather instead.
            # NOTE(woosuk): Here, the outputs of every device should not be None
            # because XLA requires strict SPMD among all devices. Every device
            # should execute the same operations after gathering the logits.
            logits = tensor_model_parallel_all_gather(logits)
+        else:
+            # None may be returned for rank > 0
+            logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[..., :self.org_vocab_size]
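To make the gather vs. all-gather distinction above concrete, here is a small standalone torch.distributed sketch (not part of this commit, and not using vLLM's `tensor_model_parallel_*` helpers): with `gather`, only the destination rank receives the shards, while `all_gather` leaves a full copy on every rank, which fits the SPMD setup where each `external_launcher` rank runs a full engine of its own:

```python
import torch
import torch.distributed as dist

# Runnable with `torchrun --nproc-per-node=2 gather_demo.py` (file name is
# arbitrary); the gloo backend keeps it CPU-only.
dist.init_process_group(backend="gloo", init_method="env://")
rank, world = dist.get_rank(), dist.get_world_size()

# Each rank's shard of some hypothetical "logits".
shard = torch.full((2,), float(rank))

# gather: only dst=0 receives the shards; other ranks pass gather_list=None.
gathered = [torch.empty(2) for _ in range(world)] if rank == 0 else None
dist.gather(shard, gather_list=gathered, dst=0)

# all_gather: every rank receives every shard.
all_shards = [torch.empty(2) for _ in range(world)]
dist.all_gather(all_shards, shard)

print(f"rank {rank}: gather -> {gathered}, all_gather -> {torch.cat(all_shards)}")
dist.destroy_process_group()
```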
@@ -246,7 +246,7 @@ class WorkerProc:
        ready_path: str,
    ):
        self.rank = rank
-        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rank=rank)
+        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
        # TODO: move `init_worker` to executor level as a collective rpc call
        all_kwargs: List[Dict] = [
            {} for _ in range(vllm_config.parallel_config.world_size)
@@ -55,9 +55,6 @@ class Worker(LocalOrDistributedWorkerBase):
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
-        if is_driver_worker:
-            assert rank % self.parallel_config.tensor_parallel_size == 0, \
-                "Driver worker should be rank 0 of tensor parallel group."
        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
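Note that the assertion removed here would reject the external-launcher setup introduced above, where every rank constructs its worker with `is_driver_worker=True` rather than only TP rank 0.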
@@ -461,7 +461,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):

class WorkerWrapperBase:
    """
-    The whole point of this class is to lazily initialize the worker.
+    This class represents one process in an executor/engine. It is responsible
+    for lazily initializing the worker and handling the worker's lifecycle.
    We first instantiate the WorkerWrapper, which remembers the worker module
    and class name. Then, when we call `update_environment_variables`, and the
    real initialization happens in `init_worker`.
@@ -470,9 +471,19 @@ class WorkerWrapperBase:
    def __init__(
        self,
        vllm_config: VllmConfig,
-        rank: int = 0,
+        rpc_rank: int = 0,
    ) -> None:
-        self.rank = rank
+        """
+        Initialize the worker wrapper with the given vllm_config and rpc_rank.
+        Note: rpc_rank is the rank of the worker in the executor. In most cases,
+        it is also the rank of the worker in the distributed group. However,
+        when multiple executors work together, they can be different.
+        e.g. in the case of SPMD-style offline inference with TP=2,
+        users can launch 2 engines/executors, each with only 1 worker.
+        All workers have rpc_rank=0, but they have different ranks in the TP
+        group.
+        """
+        self.rpc_rank = rpc_rank
        self.vllm_config = vllm_config
        self.worker: Optional[WorkerBase] = None
        if vllm_config.model_config is not None:
@@ -485,16 +496,16 @@ class WorkerWrapperBase:

    def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
        """
-        Adjust the rank based on the given mapping.
+        Adjust the rpc_rank based on the given mapping.
        It is only used during the initialization of the executor,
-        to adjust the rank of workers after we create all workers.
+        to adjust the rpc_rank of workers after we create all workers.
        """
-        if self.rank in rank_mapping:
-            self.rank = rank_mapping[self.rank]
+        if self.rpc_rank in rank_mapping:
+            self.rpc_rank = rank_mapping[self.rpc_rank]

    def update_environment_variables(self, envs_list: List[Dict[str,
                                                                 str]]) -> None:
-        envs = envs_list[self.rank]
+        envs = envs_list[self.rpc_rank]
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
@@ -507,7 +518,7 @@ class WorkerWrapperBase:
        Here we inject some common logic before initializing the worker.
        Arguments are passed to the worker class constructor.
        """
-        kwargs = all_kwargs[self.rank]
+        kwargs = all_kwargs[self.rpc_rank]
        enable_trace_function_call_for_thread(self.vllm_config)

        # see https://github.com/NVIDIA/nccl/issues/1234
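To illustrate the `rpc_rank` vs. distributed-rank distinction documented in the new docstring, here is a toy sketch (not part of this commit; the numbers are illustrative) of how the two launch modes assign ranks for TP=2:

```python
# Classic launch: one executor/engine drives both tensor-parallel workers,
# so rpc_rank and the distributed rank coincide.
single_engine = [
    {"engine": 0, "rpc_rank": 0, "distributed_rank": 0},
    {"engine": 0, "rpc_rank": 1, "distributed_rank": 1},
]

# SPMD / external_launcher: torchrun starts two engines, each with a single
# worker, so every WorkerWrapperBase is created with rpc_rank=0 while the
# distributed rank comes from the RANK environment variable.
external_launcher = [
    {"engine": 0, "rpc_rank": 0, "distributed_rank": 0},
    {"engine": 1, "rpc_rank": 0, "distributed_rank": 1},
]

for mode, rows in [("single engine", single_engine),
                   ("external_launcher", external_launcher)]:
    for row in rows:
        print(mode, row)
```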