[RLHF] use worker_extension_cls for compatibility with V0 and V1 (#14185)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
81b2f4a45f
commit
151b08e0fe
@ -145,8 +145,10 @@ steps:
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
|
||||
# TODO: create a dedicated test section for multi-GPU example tests
|
||||
# when we have multiple distributed example tests
|
||||
- python3 ../examples/offline_inference/rlhf.py
|
||||
- RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py
|
||||
- pushd ../examples/offline_inference
|
||||
- python3 rlhf.py
|
||||
- RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
|
||||
- popd
|
||||
|
||||
- label: Metrics, Tracing Test # 10min
|
||||
num_gpus: 2
|
||||
|
@ -18,72 +18,11 @@ import ray
|
||||
import torch
|
||||
from ray.util.placement_group import placement_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
from rlhf_utils import stateless_init_process_group
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.utils import get_ip, get_open_port
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
|
||||
def stateless_init_process_group(master_address, master_port, rank, world_size,
                                 device):
    """
    Create a NCCL communicator backed by a vLLM `StatelessProcessGroup`.

    vLLM provides `StatelessProcessGroup` to create a process group
    without considering the global process group in torch.distributed.
    The recommended flow is to create a `StatelessProcessGroup` first and
    then set up the data-plane communication (NCCL) between the external
    (training) processes and the vLLM workers on top of it.

    Args:
        master_address: host passed to `StatelessProcessGroup.create`.
        master_port: port passed to `StatelessProcessGroup.create`.
        rank: rank of the calling process inside the new group.
        world_size: total number of processes in the new group.
        device: device handed to `PyNcclCommunicator`.

    Returns:
        The `PyNcclCommunicator` built on the new process group.
    """
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    process_group = StatelessProcessGroup.create(
        host=master_address,
        port=master_port,
        rank=rank,
        world_size=world_size,
    )
    return PyNcclCommunicator(process_group, device=device)
|
||||
|
||||
|
||||
class MyWorker(Worker):
    """
    The `MyWorker` class inherits from `Worker` to provide custom functions.

    For simplicity, the class is defined in this self-contained script.
    Normally it should live in a separate file, and the qualified name of
    the class should be passed to the `worker_cls` parameter.
    """

    def init_weight_update_group(self, master_address, master_port,
                                 rank_offset, world_size):
        """
        Join the weight-update process group and keep its communicator.

        The rank used inside the new group is this worker's rank in the
        vLLM world group plus `rank_offset`. The resulting communicator is
        stored on `self.model_update_group`.
        """
        from vllm.distributed.parallel_state import get_world_group

        rank = get_world_group().rank + rank_offset
        self.model_update_group = stateless_init_process_group(
            master_address,
            master_port,
            rank,
            world_size,
            self.device,
        )

    def update_weight(self, name, dtype, shape):
        """
        Receive one weight broadcast from rank 0 and load it into the model.

        Args:
            name: parameter name forwarded to `load_weights`.
            dtype: dtype of the receive buffer.
            shape: shape of the receive buffer.
        """
        # Allocate on the device the communicator was created with
        # (`init_weight_update_group` passes `self.device`), instead of
        # relying on whichever CUDA device happens to be current.
        weight = torch.empty(shape, dtype=dtype, device=self.device)
        self.model_update_group.broadcast(weight,
                                          src=0,
                                          stream=torch.cuda.current_stream())

        self.model_runner.model.load_weights(weights=[(name, weight)])

        # Drop the temporary buffer as soon as the weights are loaded.
        del weight

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.

        Returns:
            True when every model parameter is (numerically) all zeros.
            Stops at the first parameter that is not.
        """
        return all(
            torch.allclose(param, torch.zeros_like(param))
            for _name, param in self.model_runner.model.named_parameters())
|
||||
|
||||
|
||||
class MyLLM(LLM):
|
||||
@ -129,7 +68,7 @@ llm = ray.remote(
|
||||
)(MyLLM).remote(
|
||||
model="facebook/opt-125m",
|
||||
enforce_eager=True,
|
||||
worker_cls=MyWorker,
|
||||
worker_extension_cls="rlhf_utils.WorkerExtension",
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend="ray",
|
||||
)
|
||||
@ -159,6 +98,7 @@ master_port = get_open_port()
|
||||
|
||||
handle = llm.collective_rpc.remote("init_weight_update_group",
|
||||
args=(master_address, master_port, 1, 3))
|
||||
|
||||
model_update_group = stateless_init_process_group(master_address, master_port,
|
||||
0, 3, torch.device("cuda:0"))
|
||||
ray.get(handle)
|
||||
|
@ -17,40 +17,6 @@ from ray.util.placement_group import placement_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
|
||||
class MyWorker(Worker):
    """
    Worker subclass used by the colocate RLHF example.

    Adds helpers to report the device UUID of the worker, to load weights
    shared through IPC handles, and to verify the loaded weights.
    """

    def report_device_id(self) -> str:
        """
        Return the UUID of this worker's device.

        The UUID is looked up through `current_platform.get_device_uuid`
        using `self.device.index` and cached on `self.device_uuid`.
        """
        from vllm.platforms import current_platform

        self.device_uuid = current_platform.get_device_uuid(self.device.index)
        return self.device_uuid

    def update_weights_from_ipc_handles(self, ipc_handles):
        """
        Rebuild tensors from IPC handles and load them into the model.

        Args:
            ipc_handles: mapping from device UUID to a mapping of
                weight name -> `(func, args)`; calling `func(*args)`
                yields the tensor for that weight.
        """
        # Resolve the UUID on demand so this method does not fail with an
        # AttributeError when `report_device_id` was not called first.
        device_uuid = getattr(self, "device_uuid", None)
        if device_uuid is None:
            device_uuid = self.report_device_id()

        handles = ipc_handles[device_uuid]
        device_id = self.device.index
        weights = []
        for name, (func, args) in handles.items():
            list_args = list(args)
            # the key is to change device id to the current device id
            # in case two processes have different CUDA_VISIBLE_DEVICES
            list_args[6] = device_id
            weights.append((name, func(*list_args)))
        self.model_runner.model.load_weights(weights=weights)
        torch.cuda.synchronize()

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.

        Returns:
            True when every model parameter is (numerically) all zeros.
            Stops at the first parameter that is not.
        """
        return all(
            torch.allclose(param, torch.zeros_like(param))
            for _name, param in self.model_runner.model.named_parameters())
|
||||
|
||||
|
||||
class MyLLM(LLM):
|
||||
@ -150,7 +116,7 @@ for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]):
|
||||
)(MyLLM).remote(
|
||||
model="facebook/opt-125m",
|
||||
enforce_eager=True,
|
||||
worker_cls=MyWorker,
|
||||
worker_extension_cls="rlhf_utils.ColocateWorkerExtension",
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend="ray",
|
||||
gpu_memory_utilization=0.4,
|
||||
|
105
examples/offline_inference/rlhf_utils.py
Normal file
105
examples/offline_inference/rlhf_utils.py
Normal file
@ -0,0 +1,105 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import torch
|
||||
|
||||
|
||||
def stateless_init_process_group(master_address, master_port, rank, world_size,
                                 device):
    """
    Create a NCCL communicator backed by a vLLM `StatelessProcessGroup`.

    vLLM provides `StatelessProcessGroup` to create a process group
    without considering the global process group in torch.distributed.
    The recommended flow is to create a `StatelessProcessGroup` first and
    then set up the data-plane communication (NCCL) between the external
    (training) processes and the vLLM workers on top of it.

    Args:
        master_address: host passed to `StatelessProcessGroup.create`.
        master_port: port passed to `StatelessProcessGroup.create`.
        rank: rank of the calling process inside the new group.
        world_size: total number of processes in the new group.
        device: device handed to `PyNcclCommunicator`.

    Returns:
        The `PyNcclCommunicator` built on the new process group.
    """
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    process_group = StatelessProcessGroup.create(
        host=master_address,
        port=master_port,
        rank=rank,
        world_size=world_size,
    )
    return PyNcclCommunicator(process_group, device=device)
|
||||
|
||||
|
||||
class WorkerExtension:
    """
    The class for vLLM's worker to inherit from.

    By defining an extension class, the code can work no matter what is
    the underlying worker class. This way, the code can be compatible
    with both vLLM V0 and V1.

    NOTE: we define this class in a separate module, and the main module
    should pass the fully qualified name as the `worker_extension_cls`
    argument.
    """

    def init_weight_update_group(self, master_address, master_port,
                                 rank_offset, world_size):
        """
        Join the weight-update process group and keep its communicator.

        The rank used inside the new group is this worker's rank in the
        vLLM world group plus `rank_offset`. The resulting communicator is
        stored on `self.model_update_group`.
        """
        from vllm.distributed.parallel_state import get_world_group

        rank = get_world_group().rank + rank_offset
        self.model_update_group = stateless_init_process_group(
            master_address,
            master_port,
            rank,
            world_size,
            self.device,
        )

    def update_weight(self, name, dtype, shape):
        """
        Receive one weight broadcast from rank 0 and load it into the model.

        Args:
            name: parameter name forwarded to `load_weights`.
            dtype: dtype of the receive buffer.
            shape: shape of the receive buffer.
        """
        # Allocate on the device the communicator was created with
        # (`init_weight_update_group` passes `self.device`), instead of
        # relying on whichever CUDA device happens to be current.
        weight = torch.empty(shape, dtype=dtype, device=self.device)
        self.model_update_group.broadcast(weight,
                                          src=0,
                                          stream=torch.cuda.current_stream())

        self.model_runner.model.load_weights(weights=[(name, weight)])

        # Drop the temporary buffer as soon as the weights are loaded.
        del weight

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.

        Returns:
            True when every model parameter is (numerically) all zeros.
            Stops at the first parameter that is not.
        """
        return all(
            torch.allclose(param, torch.zeros_like(param))
            for _name, param in self.model_runner.model.named_parameters())
|
||||
|
||||
|
||||
class ColocateWorkerExtension:
    """
    The class for vLLM's worker to inherit from, in the colocate setting.

    By defining an extension class, the code can work no matter what is
    the underlying worker class. This way, the code can be compatible
    with both vLLM V0 and V1.

    NOTE: we define this class in a separate module, and the main module
    should pass the fully qualified name as the `worker_extension_cls`
    argument.
    """

    def report_device_id(self) -> str:
        """
        Return the UUID of this worker's device.

        The UUID is looked up through `current_platform.get_device_uuid`
        using `self.device.index` and cached on `self.device_uuid`.
        """
        from vllm.platforms import current_platform

        self.device_uuid = current_platform.get_device_uuid(self.device.index)
        return self.device_uuid

    def update_weights_from_ipc_handles(self, ipc_handles):
        """
        Rebuild tensors from IPC handles and load them into the model.

        Args:
            ipc_handles: mapping from device UUID to a mapping of
                weight name -> `(func, args)`; calling `func(*args)`
                yields the tensor for that weight.
        """
        # Resolve the UUID on demand so this method does not fail with an
        # AttributeError when `report_device_id` was not called first.
        device_uuid = getattr(self, "device_uuid", None)
        if device_uuid is None:
            device_uuid = self.report_device_id()

        handles = ipc_handles[device_uuid]
        device_id = self.device.index
        weights = []
        for name, (func, args) in handles.items():
            list_args = list(args)
            # the key is to change device id to the current device id
            # in case two processes have different CUDA_VISIBLE_DEVICES
            list_args[6] = device_id
            weights.append((name, func(*list_args)))
        self.model_runner.model.load_weights(weights=weights)
        torch.cuda.synchronize()

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.

        Returns:
            True when every model parameter is (numerically) all zeros.
            Stops at the first parameter that is not.
        """
        return all(
            torch.allclose(param, torch.zeros_like(param))
            for _name, param in self.model_runner.model.named_parameters())
|
@ -1366,6 +1366,7 @@ class ParallelConfig:
|
||||
# will be determined based on the platform.
|
||||
worker_cls: str = "auto"
|
||||
sd_worker_cls: str = "auto"
|
||||
worker_extension_cls: str = ""
|
||||
|
||||
# world_size is TPxPP, it affects the number of workers we create.
|
||||
world_size: int = field(init=False)
|
||||
@ -1523,6 +1524,9 @@ class ParallelConfig:
|
||||
raise ValueError("Unable to use nsight profiling unless workers "
|
||||
"run with Ray.")
|
||||
|
||||
assert isinstance(self.worker_extension_cls, str), (
|
||||
"worker_extension_cls must be a string (qualified class name).")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchedulerConfig:
|
||||
|
@ -202,6 +202,7 @@ class EngineArgs:
|
||||
override_pooler_config: Optional[PoolerConfig] = None
|
||||
compilation_config: Optional[CompilationConfig] = None
|
||||
worker_cls: str = "auto"
|
||||
worker_extension_cls: str = ""
|
||||
|
||||
kv_transfer_config: Optional[KVTransferConfig] = None
|
||||
|
||||
@ -1015,6 +1016,13 @@ class EngineArgs:
|
||||
type=str,
|
||||
default="auto",
|
||||
help='The worker class to use for distributed execution.')
|
||||
parser.add_argument(
|
||||
'--worker-extension-cls',
|
||||
type=str,
|
||||
default="",
|
||||
help='The worker extension class on top of the worker cls, '
|
||||
'it is useful if you just want to add new functions to the worker '
|
||||
'class without changing the existing functions.')
|
||||
parser.add_argument(
|
||||
"--generation-config",
|
||||
type=nullable_str,
|
||||
@ -1209,6 +1217,7 @@ class EngineArgs:
|
||||
ray_workers_use_nsight=self.ray_workers_use_nsight,
|
||||
distributed_executor_backend=self.distributed_executor_backend,
|
||||
worker_cls=self.worker_cls,
|
||||
worker_extension_cls=self.worker_extension_cls,
|
||||
)
|
||||
|
||||
max_model_len = model_config.max_model_len
|
||||
|
@ -558,10 +558,37 @@ class WorkerWrapperBase:
|
||||
worker_class = resolve_obj_by_qualname(
|
||||
self.vllm_config.parallel_config.worker_cls)
|
||||
else:
|
||||
logger.warning(
|
||||
"passing worker_cls as a class object is strongly deprecated,"
|
||||
" as the serialization of class objects can be tricky and"
|
||||
" error-prone. To be safe, please keep the class in a separate"
|
||||
" module and pass the qualified name of the class as a string."
|
||||
)
|
||||
assert isinstance(self.vllm_config.parallel_config.worker_cls,
|
||||
bytes)
|
||||
worker_class = cloudpickle.loads(
|
||||
self.vllm_config.parallel_config.worker_cls)
|
||||
if self.vllm_config.parallel_config.worker_extension_cls:
|
||||
worker_extension_cls = resolve_obj_by_qualname(
|
||||
self.vllm_config.parallel_config.worker_extension_cls)
|
||||
extended_calls = []
|
||||
if worker_extension_cls not in worker_class.__bases__:
|
||||
# check any conflicts between worker and worker_extension_cls
|
||||
for attr in dir(worker_extension_cls):
|
||||
if attr.startswith("__"):
|
||||
continue
|
||||
assert not hasattr(worker_class, attr), (
|
||||
f"Worker class {worker_class} already has an attribute"
|
||||
f" {attr}, which conflicts with the worker"
|
||||
f" extension class {worker_extension_cls}.")
|
||||
if callable(getattr(worker_extension_cls, attr)):
|
||||
extended_calls.append(attr)
|
||||
# dynamically inherit the worker extension class
|
||||
worker_class.__bases__ = worker_class.__bases__ + (
|
||||
worker_extension_cls, )
|
||||
logger.info(
|
||||
"Injected %s into %s for extended collective_rpc calls %s",
|
||||
worker_extension_cls, worker_class, extended_calls)
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
# To make vLLM config available during worker initialization
|
||||
self.worker = worker_class(**kwargs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user