[core] LLM.collective_rpc interface and RLHF example (#12084)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2025-01-16 20:19:52 +08:00 committed by GitHub
parent bf53e0c70b
commit 92e793d91a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 270 additions and 35 deletions

View File

@ -126,11 +126,15 @@ steps:
- tests/distributed
- tests/spec_decode/e2e/test_integration_dist_tp4
- tests/compile
- examples/offline_inference/rlhf.py
commands:
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
- python3 ../examples/offline_inference/rlhf.py
- label: Metrics, Tracing Test # 10min
num_gpus: 2

View File

@ -0,0 +1,191 @@
"""
a simple demonstration of RLHF with vLLM, inspired by
the OpenRLHF framework https://github.com/OpenRLHF/OpenRLHF .
It follows the design that, training processes and inference processes
are different, and they live on different GPUs.
Training processes send prompts to inference processes to generate data,
and also synchronize the weights of the model by broadcasting the weights
from the training process to the inference process.
Note that this is a simple demonstration of one training instance and one
inference instance. In practice, there could be multiple training instances
and multiple inference instances. For the full implementation, please refer
to the OpenRLHF framework.
"""
import os
import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams, configure_as_vllm_process
from vllm.utils import get_ip, get_open_port
from vllm.worker.worker import Worker
def stateless_init_process_group(master_address, master_port, rank, world_size,
device):
"""
vLLM provides `StatelessProcessGroup` to create a process group
without considering the global process group in torch.distributed.
It is recommended to create `StatelessProcessGroup`, and then initialize
the data-plane communication (NCCL) between external (train processes)
and vLLM workers.
"""
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
pg = StatelessProcessGroup.create(host=master_address,
port=master_port,
rank=rank,
world_size=world_size)
pynccl = PyNcclCommunicator(pg, device=device)
return pynccl
class MyWorker(Worker):
"""
The `MyWorker` class inherits from `Worker` to provide custom functions.
For simplicity, we define the `MyWorker` class in this self-contained
script. Normally, we should define the `MyWorker` class in a separate
file and pass the qualified name of the class to the `worker_cls`
parameter.
"""
def init_weight_update_group(self, master_address, master_port,
rank_offset, world_size):
from vllm.distributed.parallel_state import get_world_group
rank = get_world_group().rank + rank_offset
self.model_update_group = stateless_init_process_group(
master_address,
master_port,
rank,
world_size,
self.device,
)
def update_weight(self, name, dtype, shape):
weight = torch.empty(shape, dtype=dtype, device="cuda")
self.model_update_group.broadcast(weight,
src=0,
stream=torch.cuda.current_stream())
self.model_runner.model.load_weights(weights=[(name, weight)])
del weight
def check_weights_changed(self):
"""
Check if the weights are updated to 0.
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose(
p, torch.zeros_like(p))
return weights_updated
class MyLLM(LLM):
def __init__(self, *args, **kwargs):
# a hack to make the script work.
# stop ray from manipulating CUDA_VISIBLE_DEVICES
# at the top-level
del os.environ["CUDA_VISIBLE_DEVICES"]
super().__init__(*args, **kwargs)
"""
Start the training process, here we use huggingface transformers
as an example to hold a model on GPU 0.
It is important for all the processes outside of vLLM to call
`configure_as_vllm_process` to set some common environment variables
the same as vLLM workers.
"""
configure_as_vllm_process()
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")
"""
Start the inference process, here we use vLLM to hold a model on GPU 1 and
GPU 2. For the details on how to use ray, please refer to the ray
documentation https://docs.ray.io/en/latest/ .
"""
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
placement_group=pg_inference,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=0,
)
"""
launch the vLLM inference engine.
here we use `enforce_eager` to reduce the start time.
"""
llm = ray.remote(
num_cpus=0,
num_gpus=0,
scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
model="facebook/opt-125m",
enforce_eager=True,
worker_cls=MyWorker,
tensor_parallel_size=2,
distributed_executor_backend="ray",
)
# Generate texts from the prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
outputs = ray.get(llm.generate.remote(prompts, sampling_params))
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")
# set up the communication between the training process
# and the inference engine.
master_address = get_ip()
master_port = get_open_port()
handle = llm.collective_rpc.remote("init_weight_update_group",
args=(master_address, master_port, 1, 3))
model_update_group = stateless_init_process_group(master_address, master_port,
0, 3, torch.device("cuda:0"))
ray.get(handle)
# simulate training, modify the weights of the model.
for name, p in train_model.named_parameters():
p.data.zero_()
# sync weight from the training process to the inference engine.
for name, p in train_model.named_parameters():
handle = llm.collective_rpc.remote("update_weight",
args=(name, p.dtype, p.shape))
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
ray.get(handle)
# check if the weights are updated.
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
# use the updated model to generate texts, they will be nonsense
# because the weights are all zeros.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
for output in outputs_updated:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")

View File

@ -17,6 +17,44 @@ from vllm.sampling_params import SamplingParams
from .version import __version__, __version_tuple__
def configure_as_vllm_process():
"""
set some common config/environment variables that should be set
for all processes created by vllm and all processes
that interact with vllm workers.
"""
import os
import torch
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'
# see https://github.com/vllm-project/vllm/issues/10480
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
# see https://github.com/vllm-project/vllm/issues/10619
torch._inductor.config.compile_threads = 1
from vllm.platforms import current_platform
if current_platform.is_xpu():
# see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
torch._dynamo.config.disable = True
elif current_platform.is_hpu():
# NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
# does not support torch.compile
# Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
# torch.compile support
is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
if is_lazy:
torch._dynamo.config.disable = True
# NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
# requires enabling lazy collectives
# see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'
__all__ = [
"__version__",
"__version_tuple__",
@ -42,4 +80,5 @@ __all__ = [
"AsyncEngineArgs",
"initialize_ray_cluster",
"PoolingParams",
"configure_as_vllm_process",
]

View File

@ -4,6 +4,7 @@ from contextlib import contextmanager
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
Union, cast, overload)
import cloudpickle
from tqdm import tqdm
from typing_extensions import deprecated
@ -186,6 +187,13 @@ class LLM:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
if "worker_cls" in kwargs:
worker_cls = kwargs["worker_cls"]
# if the worker_cls is not qualified string name,
# we serialize it using cloudpickle to avoid pickling issues
if isinstance(worker_cls, type):
kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)
if compilation_config is not None:
if isinstance(compilation_config, (int, dict)):
compilation_config_instance = CompilationConfig.from_cli(
@ -455,6 +463,23 @@ class LLM:
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs, RequestOutput)
def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
"""
Run a method on all workers, with homogeneous arguments.
The main extension point for the LLM entrypoint.
Users can provide custom worker class through `worker_cls`
argument, and implement new methods in the worker class.
Then, users can call the new methods through this API.
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
return self.llm_engine.model_executor.collective_rpc(
method, timeout, args, kwargs)
def beam_search(
self,
prompts: List[Union[TokensPrompt, TextPrompt]],

View File

@ -1,9 +1,6 @@
import logging
import os
from typing import Callable, Dict
import torch
import vllm.envs as envs
logger = logging.getLogger(__name__)
@ -50,34 +47,6 @@ def load_general_plugins():
processes. They should be designed in a way that they can be loaded
multiple times without causing issues.
"""
# all processes created by vllm will load plugins,
# and here we can inject some common environment variables
# for all processes.
# see https://github.com/vllm-project/vllm/issues/10480
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
# see https://github.com/vllm-project/vllm/issues/10619
torch._inductor.config.compile_threads = 1
from vllm.platforms import current_platform
if current_platform.is_xpu():
# see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
torch._dynamo.config.disable = True
if current_platform.is_hpu():
# NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
# does not support torch.compile
# Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
# torch.compile support
is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
if is_lazy:
torch._dynamo.config.disable = True
# NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
# requires enabling lazy collectives
# see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'
global plugins_loaded
if plugins_loaded:
return

View File

@ -4,6 +4,7 @@ import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
import cloudpickle
import torch
from vllm.config import ObservabilityConfig, VllmConfig
@ -521,14 +522,20 @@ class WorkerWrapperBase:
kwargs = all_kwargs[self.rpc_rank]
enable_trace_function_call_for_thread(self.vllm_config)
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'
from vllm import configure_as_vllm_process
configure_as_vllm_process()
from vllm.plugins import load_general_plugins
load_general_plugins()
worker_class = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_cls)
if isinstance(self.vllm_config.parallel_config.worker_cls, str):
worker_class = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_cls)
else:
assert isinstance(self.vllm_config.parallel_config.worker_cls,
bytes)
worker_class = cloudpickle.loads(
self.vllm_config.parallel_config.worker_cls)
self.worker = worker_class(**kwargs)
assert self.worker is not None