[RLHF] use worker_extension_cls for compatibility with V0 and V1 (#14185)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2025-03-07 00:32:46 +08:00 committed by GitHub
parent 81b2f4a45f
commit 151b08e0fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 153 additions and 100 deletions

View File

@ -145,8 +145,10 @@ steps:
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
- python3 ../examples/offline_inference/rlhf.py
- RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py
- pushd ../examples/offline_inference
- python3 rlhf.py
- RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
- popd
- label: Metrics, Tracing Test # 10min
num_gpus: 2

View File

@ -18,72 +18,11 @@ import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from rlhf_utils import stateless_init_process_group
from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams
from vllm.utils import get_ip, get_open_port
from vllm.worker.worker import Worker
def stateless_init_process_group(master_address, master_port, rank, world_size,
device):
"""
vLLM provides `StatelessProcessGroup` to create a process group
without considering the global process group in torch.distributed.
It is recommended to create `StatelessProcessGroup`, and then initialize
the data-plane communication (NCCL) between external (train processes)
and vLLM workers.
"""
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
pg = StatelessProcessGroup.create(host=master_address,
port=master_port,
rank=rank,
world_size=world_size)
pynccl = PyNcclCommunicator(pg, device=device)
return pynccl
class MyWorker(Worker):
"""
The `MyWorker` class inherits from `Worker` to provide custom functions.
For simplicity, we define the `MyWorker` class in this self-contained
script. Normally, we should define the `MyWorker` class in a separate
file and pass the qualified name of the class to the `worker_cls`
parameter.
"""
def init_weight_update_group(self, master_address, master_port,
rank_offset, world_size):
from vllm.distributed.parallel_state import get_world_group
rank = get_world_group().rank + rank_offset
self.model_update_group = stateless_init_process_group(
master_address,
master_port,
rank,
world_size,
self.device,
)
def update_weight(self, name, dtype, shape):
weight = torch.empty(shape, dtype=dtype, device="cuda")
self.model_update_group.broadcast(weight,
src=0,
stream=torch.cuda.current_stream())
self.model_runner.model.load_weights(weights=[(name, weight)])
del weight
def check_weights_changed(self):
"""
Check if the weights are updated to 0.
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose(
p, torch.zeros_like(p))
return weights_updated
class MyLLM(LLM):
@ -129,7 +68,7 @@ llm = ray.remote(
)(MyLLM).remote(
model="facebook/opt-125m",
enforce_eager=True,
worker_cls=MyWorker,
worker_extension_cls="rlhf_utils.WorkerExtension",
tensor_parallel_size=2,
distributed_executor_backend="ray",
)
@ -159,6 +98,7 @@ master_port = get_open_port()
handle = llm.collective_rpc.remote("init_weight_update_group",
args=(master_address, master_port, 1, 3))
model_update_group = stateless_init_process_group(master_address, master_port,
0, 3, torch.device("cuda:0"))
ray.get(handle)

View File

@ -17,40 +17,6 @@ from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from vllm import LLM
from vllm.worker.worker import Worker
class MyWorker(Worker):
def report_device_id(self) -> str:
from vllm.platforms import current_platform
self.device_uuid = current_platform.get_device_uuid(self.device.index)
return self.device_uuid
def update_weights_from_ipc_handles(self, ipc_handles):
handles = ipc_handles[self.device_uuid]
device_id = self.device.index
weights = []
for name, handle in handles.items():
func, args = handle
list_args = list(args)
# the key is to change device id to the current device id
# in case two processes have different CUDA_VISIBLE_DEVICES
list_args[6] = device_id
tensor = func(*list_args)
weights.append((name, tensor))
self.model_runner.model.load_weights(weights=weights)
torch.cuda.synchronize()
def check_weights_changed(self):
"""
Check if the weights are updated to 0.
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose(
p, torch.zeros_like(p))
return weights_updated
class MyLLM(LLM):
@ -150,7 +116,7 @@ for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]):
)(MyLLM).remote(
model="facebook/opt-125m",
enforce_eager=True,
worker_cls=MyWorker,
worker_extension_cls="rlhf_utils.ColocateWorkerExtension",
tensor_parallel_size=2,
distributed_executor_backend="ray",
gpu_memory_utilization=0.4,

View File

@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
import torch
def stateless_init_process_group(master_address, master_port, rank, world_size,
device):
"""
vLLM provides `StatelessProcessGroup` to create a process group
without considering the global process group in torch.distributed.
It is recommended to create `StatelessProcessGroup`, and then initialize
the data-plane communication (NCCL) between external (train processes)
and vLLM workers.
"""
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
pg = StatelessProcessGroup.create(host=master_address,
port=master_port,
rank=rank,
world_size=world_size)
pynccl = PyNcclCommunicator(pg, device=device)
return pynccl
class WorkerExtension:
"""
The class for vLLM's worker to inherit from.
By defining an extension class, the code can work no matter what is
the underlying worker class. This way, the code can be compatible
with both vLLM V0 and V1.
NOTE: we define this class in a separate module, and the main module
should pass the full qualified name as `worker_extension_cls` argument.
"""
def init_weight_update_group(self, master_address, master_port,
rank_offset, world_size):
from vllm.distributed.parallel_state import get_world_group
rank = get_world_group().rank + rank_offset
self.model_update_group = stateless_init_process_group(
master_address,
master_port,
rank,
world_size,
self.device,
)
def update_weight(self, name, dtype, shape):
weight = torch.empty(shape, dtype=dtype, device="cuda")
self.model_update_group.broadcast(weight,
src=0,
stream=torch.cuda.current_stream())
self.model_runner.model.load_weights(weights=[(name, weight)])
del weight
def check_weights_changed(self):
"""
Check if the weights are updated to 0.
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose(
p, torch.zeros_like(p))
return weights_updated
class ColocateWorkerExtension:
"""
The class for vLLM's worker to inherit from, in the colocate setting.
By defining an extension class, the code can work no matter what is
the underlying worker class. This way, the code can be compatible
with both vLLM V0 and V1.
NOTE: we define this class in a separate module, and the main module
should pass the full qualified name as `worker_extension_cls` argument.
"""
def report_device_id(self) -> str:
from vllm.platforms import current_platform
self.device_uuid = current_platform.get_device_uuid(self.device.index)
return self.device_uuid
def update_weights_from_ipc_handles(self, ipc_handles):
handles = ipc_handles[self.device_uuid]
device_id = self.device.index
weights = []
for name, handle in handles.items():
func, args = handle
list_args = list(args)
# the key is to change device id to the current device id
# in case two processes have different CUDA_VISIBLE_DEVICES
list_args[6] = device_id
tensor = func(*list_args)
weights.append((name, tensor))
self.model_runner.model.load_weights(weights=weights)
torch.cuda.synchronize()
def check_weights_changed(self):
"""
Check if the weights are updated to 0.
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose(
p, torch.zeros_like(p))
return weights_updated

View File

@ -1366,6 +1366,7 @@ class ParallelConfig:
# will be determined based on the platform.
worker_cls: str = "auto"
sd_worker_cls: str = "auto"
worker_extension_cls: str = ""
# world_size is TPxPP, it affects the number of workers we create.
world_size: int = field(init=False)
@ -1523,6 +1524,9 @@ class ParallelConfig:
raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.")
assert isinstance(self.worker_extension_cls, str), (
"worker_extension_cls must be a string (qualified class name).")
@dataclass
class SchedulerConfig:

View File

@ -202,6 +202,7 @@ class EngineArgs:
override_pooler_config: Optional[PoolerConfig] = None
compilation_config: Optional[CompilationConfig] = None
worker_cls: str = "auto"
worker_extension_cls: str = ""
kv_transfer_config: Optional[KVTransferConfig] = None
@ -1015,6 +1016,13 @@ class EngineArgs:
type=str,
default="auto",
help='The worker class to use for distributed execution.')
parser.add_argument(
'--worker-extension-cls',
type=str,
default="",
help='The worker extension class on top of the worker cls, '
'it is useful if you just want to add new functions to the worker '
'class without changing the existing functions.')
parser.add_argument(
"--generation-config",
type=nullable_str,
@ -1209,6 +1217,7 @@ class EngineArgs:
ray_workers_use_nsight=self.ray_workers_use_nsight,
distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
)
max_model_len = model_config.max_model_len

View File

@ -558,10 +558,37 @@ class WorkerWrapperBase:
worker_class = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_cls)
else:
logger.warning(
"passing worker_cls as a class object is strongly deprecated,"
" as the serialization of class objects can be tricky and"
" error-prone. To be safe, please keep the class in a separate"
" module and pass the qualified name of the class as a string."
)
assert isinstance(self.vllm_config.parallel_config.worker_cls,
bytes)
worker_class = cloudpickle.loads(
self.vllm_config.parallel_config.worker_cls)
if self.vllm_config.parallel_config.worker_extension_cls:
worker_extension_cls = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_extension_cls)
extended_calls = []
if worker_extension_cls not in worker_class.__bases__:
# check any conflicts between worker and worker_extension_cls
for attr in dir(worker_extension_cls):
if attr.startswith("__"):
continue
assert not hasattr(worker_class, attr), (
f"Worker class {worker_class} already has an attribute"
f" {attr}, which conflicts with the worker"
f" extension class {worker_extension_cls}.")
if callable(getattr(worker_extension_cls, attr)):
extended_calls.append(attr)
# dynamically inherit the worker extension class
worker_class.__bases__ = worker_class.__bases__ + (
worker_extension_cls, )
logger.info(
"Injected %s into %s for extended collective_rpc calls %s",
worker_extension_cls, worker_class, extended_calls)
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs)