[RLHF] use worker_extension_cls for compatibility with V0 and V1 (#14185)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent 81b2f4a45f
commit 151b08e0fe
.buildkite/test-pipeline.yaml
@@ -145,8 +145,10 @@ steps:
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
   # TODO: create a dedicated test section for multi-GPU example tests
   # when we have multiple distributed example tests
-  - python3 ../examples/offline_inference/rlhf.py
-  - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py
+  - pushd ../examples/offline_inference
+  - python3 rlhf.py
+  - RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
+  - popd
 
 - label: Metrics, Tracing Test # 10min
   num_gpus: 2
examples/offline_inference/rlhf.py
@@ -18,72 +18,11 @@ import ray
 import torch
 from ray.util.placement_group import placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+from rlhf_utils import stateless_init_process_group
 from transformers import AutoModelForCausalLM
 
 from vllm import LLM, SamplingParams
 from vllm.utils import get_ip, get_open_port
-from vllm.worker.worker import Worker
-
-
-def stateless_init_process_group(master_address, master_port, rank, world_size,
-                                 device):
-    """
-    vLLM provides `StatelessProcessGroup` to create a process group
-    without considering the global process group in torch.distributed.
-    It is recommended to create `StatelessProcessGroup`, and then initialize
-    the data-plane communication (NCCL) between external (train processes)
-    and vLLM workers.
-    """
-    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
-    from vllm.distributed.utils import StatelessProcessGroup
-    pg = StatelessProcessGroup.create(host=master_address,
-                                      port=master_port,
-                                      rank=rank,
-                                      world_size=world_size)
-    pynccl = PyNcclCommunicator(pg, device=device)
-    return pynccl
-
-
-class MyWorker(Worker):
-    """
-    The `MyWorker` class inherits from `Worker` to provide custom functions.
-    For simplicity, we define the `MyWorker` class in this self-contained
-    script. Normally, we should define the `MyWorker` class in a separate
-    file and pass the qualified name of the class to the `worker_cls`
-    parameter.
-    """
-
-    def init_weight_update_group(self, master_address, master_port,
-                                 rank_offset, world_size):
-        from vllm.distributed.parallel_state import get_world_group
-        rank = get_world_group().rank + rank_offset
-        self.model_update_group = stateless_init_process_group(
-            master_address,
-            master_port,
-            rank,
-            world_size,
-            self.device,
-        )
-
-    def update_weight(self, name, dtype, shape):
-        weight = torch.empty(shape, dtype=dtype, device="cuda")
-        self.model_update_group.broadcast(weight,
-                                          src=0,
-                                          stream=torch.cuda.current_stream())
-
-        self.model_runner.model.load_weights(weights=[(name, weight)])
-
-        del weight
-
-    def check_weights_changed(self):
-        """
-        Check if the weights are updated to 0.
-        """
-        weights_updated = True
-        for name, p in self.model_runner.model.named_parameters():
-            weights_updated = weights_updated and torch.allclose(
-                p, torch.zeros_like(p))
-        return weights_updated
-
-
 class MyLLM(LLM):
@@ -129,7 +68,7 @@ llm = ray.remote(
 )(MyLLM).remote(
     model="facebook/opt-125m",
     enforce_eager=True,
-    worker_cls=MyWorker,
+    worker_extension_cls="rlhf_utils.WorkerExtension",
     tensor_parallel_size=2,
     distributed_executor_backend="ray",
 )
@@ -159,6 +98,7 @@ master_port = get_open_port()
 
 handle = llm.collective_rpc.remote("init_weight_update_group",
                                    args=(master_address, master_port, 1, 3))
+
 model_update_group = stateless_init_process_group(master_address, master_port,
                                                   0, 3, torch.device("cuda:0"))
 ray.get(handle)
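For orientation, a hedged sketch of the weight-sync loop the driver runs on top of these extension methods. `train_model` stands for the Hugging Face model the example loads earlier and is assumed to live on cuda:0 (rank 0 of the weight-update group); the loop itself is illustrative and not part of this hunk.

# Hedged sketch: push updated training weights to the vLLM workers through
# the extension methods defined in rlhf_utils.WorkerExtension.
for name, p in train_model.named_parameters():
    # each vLLM worker posts a matching receive for this tensor ...
    handle = llm.collective_rpc.remote("update_weight",
                                       args=(name, p.dtype, p.shape))
    # ... while the training process broadcasts it from rank 0
    model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
    ray.get(handle)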
examples/offline_inference/rlhf_colocate.py
@@ -17,40 +17,6 @@ from ray.util.placement_group import placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
 from vllm import LLM
-from vllm.worker.worker import Worker
-
-
-class MyWorker(Worker):
-
-    def report_device_id(self) -> str:
-        from vllm.platforms import current_platform
-        self.device_uuid = current_platform.get_device_uuid(self.device.index)
-        return self.device_uuid
-
-    def update_weights_from_ipc_handles(self, ipc_handles):
-        handles = ipc_handles[self.device_uuid]
-        device_id = self.device.index
-        weights = []
-        for name, handle in handles.items():
-            func, args = handle
-            list_args = list(args)
-            # the key is to change device id to the current device id
-            # in case two processes have different CUDA_VISIBLE_DEVICES
-            list_args[6] = device_id
-            tensor = func(*list_args)
-            weights.append((name, tensor))
-        self.model_runner.model.load_weights(weights=weights)
-        torch.cuda.synchronize()
-
-    def check_weights_changed(self):
-        """
-        Check if the weights are updated to 0.
-        """
-        weights_updated = True
-        for name, p in self.model_runner.model.named_parameters():
-            weights_updated = weights_updated and torch.allclose(
-                p, torch.zeros_like(p))
-        return weights_updated
-
-
 class MyLLM(LLM):
@@ -150,7 +116,7 @@ for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]):
     )(MyLLM).remote(
         model="facebook/opt-125m",
         enforce_eager=True,
-        worker_cls=MyWorker,
+        worker_extension_cls="rlhf_utils.ColocateWorkerExtension",
         tensor_parallel_size=2,
         distributed_executor_backend="ray",
         gpu_memory_utilization=0.4,
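The colocate variant moves weights through CUDA IPC handles keyed by GPU UUID. A minimal sketch of how the training side could build such a dict, assuming it uses torch's IPC reduction; the helper below is illustrative and not part of this diff.

# Sketch (assumption): build the `ipc_handles` dict consumed by
# ColocateWorkerExtension.update_weights_from_ipc_handles. Each value is the
# (rebuild_func, args) pair produced by torch's CUDA IPC reduction; args[6]
# is the device index, which the worker rewrites to its own index.
from torch.multiprocessing.reductions import reduce_tensor


def gather_ipc_handles(model, device_uuid):
    # device_uuid is what ColocateWorkerExtension.report_device_id returns
    handles = {name: reduce_tensor(p.detach())
               for name, p in model.named_parameters()}
    return {device_uuid: handles}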
examples/offline_inference/rlhf_utils.py (new file, 105 lines)
@@ -0,0 +1,105 @@
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def stateless_init_process_group(master_address, master_port, rank, world_size,
+                                 device):
+    """
+    vLLM provides `StatelessProcessGroup` to create a process group
+    without considering the global process group in torch.distributed.
+    It is recommended to create `StatelessProcessGroup`, and then initialize
+    the data-plane communication (NCCL) between external (train processes)
+    and vLLM workers.
+    """
+    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+    from vllm.distributed.utils import StatelessProcessGroup
+    pg = StatelessProcessGroup.create(host=master_address,
+                                      port=master_port,
+                                      rank=rank,
+                                      world_size=world_size)
+    pynccl = PyNcclCommunicator(pg, device=device)
+    return pynccl
+
+
+class WorkerExtension:
+    """
+    The class for vLLM's worker to inherit from.
+    By defining an extension class, the code can work no matter what is
+    the underlying worker class. This way, the code can be compatible
+    with both vLLM V0 and V1.
+    NOTE: we define this class in a separate module, and the main module
+    should pass the full qualified name as `worker_extension_cls` argument.
+    """
+
+    def init_weight_update_group(self, master_address, master_port,
+                                 rank_offset, world_size):
+        from vllm.distributed.parallel_state import get_world_group
+        rank = get_world_group().rank + rank_offset
+        self.model_update_group = stateless_init_process_group(
+            master_address,
+            master_port,
+            rank,
+            world_size,
+            self.device,
+        )
+
+    def update_weight(self, name, dtype, shape):
+        weight = torch.empty(shape, dtype=dtype, device="cuda")
+        self.model_update_group.broadcast(weight,
+                                          src=0,
+                                          stream=torch.cuda.current_stream())
+
+        self.model_runner.model.load_weights(weights=[(name, weight)])
+
+        del weight
+
+    def check_weights_changed(self):
+        """
+        Check if the weights are updated to 0.
+        """
+        weights_updated = True
+        for name, p in self.model_runner.model.named_parameters():
+            weights_updated = weights_updated and torch.allclose(
+                p, torch.zeros_like(p))
+        return weights_updated
+
+
+class ColocateWorkerExtension:
+    """
+    The class for vLLM's worker to inherit from, in the colocate setting.
+    By defining an extension class, the code can work no matter what is
+    the underlying worker class. This way, the code can be compatible
+    with both vLLM V0 and V1.
+    NOTE: we define this class in a separate module, and the main module
+    should pass the full qualified name as `worker_extension_cls` argument.
+    """
+
+    def report_device_id(self) -> str:
+        from vllm.platforms import current_platform
+        self.device_uuid = current_platform.get_device_uuid(self.device.index)
+        return self.device_uuid
+
+    def update_weights_from_ipc_handles(self, ipc_handles):
+        handles = ipc_handles[self.device_uuid]
+        device_id = self.device.index
+        weights = []
+        for name, handle in handles.items():
+            func, args = handle
+            list_args = list(args)
+            # the key is to change device id to the current device id
+            # in case two processes have different CUDA_VISIBLE_DEVICES
+            list_args[6] = device_id
+            tensor = func(*list_args)
+            weights.append((name, tensor))
+        self.model_runner.model.load_weights(weights=weights)
+        torch.cuda.synchronize()
+
+    def check_weights_changed(self):
+        """
+        Check if the weights are updated to 0.
+        """
+        weights_updated = True
+        for name, p in self.model_runner.model.named_parameters():
+            weights_updated = weights_updated and torch.allclose(
+                p, torch.zeros_like(p))
+        return weights_updated
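A minimal usage sketch for the new module (assuming rlhf_utils.py is importable in the launching process, e.g. the script is run from examples/offline_inference): the extension is referenced by its qualified name, so nothing is pickled and the same string works whether the V0 or V1 worker class is selected.

# Minimal usage sketch, not part of this diff.
from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",
    enforce_eager=True,
    worker_extension_cls="rlhf_utils.WorkerExtension",
)
# methods defined on the extension are reachable by name via collective_rpc;
# here each worker reports whether its weights are currently all zeros.
print(llm.collective_rpc("check_weights_changed"))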
vllm/config.py
@@ -1366,6 +1366,7 @@ class ParallelConfig:
     # will be determined based on the platform.
     worker_cls: str = "auto"
     sd_worker_cls: str = "auto"
+    worker_extension_cls: str = ""
 
     # world_size is TPxPP, it affects the number of workers we create.
     world_size: int = field(init=False)
@@ -1523,6 +1524,9 @@ class ParallelConfig:
             raise ValueError("Unable to use nsight profiling unless workers "
                              "run with Ray.")
 
+        assert isinstance(self.worker_extension_cls, str), (
+            "worker_extension_cls must be a string (qualified class name).")
+
 
 @dataclass
 class SchedulerConfig:
vllm/engine/arg_utils.py
@@ -202,6 +202,7 @@ class EngineArgs:
     override_pooler_config: Optional[PoolerConfig] = None
     compilation_config: Optional[CompilationConfig] = None
     worker_cls: str = "auto"
+    worker_extension_cls: str = ""
 
     kv_transfer_config: Optional[KVTransferConfig] = None
 
@@ -1015,6 +1016,13 @@ class EngineArgs:
             type=str,
             default="auto",
             help='The worker class to use for distributed execution.')
+        parser.add_argument(
+            '--worker-extension-cls',
+            type=str,
+            default="",
+            help='The worker extension class on top of the worker cls, '
+            'it is useful if you just want to add new functions to the worker '
+            'class without changing the existing functions.')
         parser.add_argument(
             "--generation-config",
             type=nullable_str,
@@ -1209,6 +1217,7 @@ class EngineArgs:
             ray_workers_use_nsight=self.ray_workers_use_nsight,
             distributed_executor_backend=self.distributed_executor_backend,
             worker_cls=self.worker_cls,
+            worker_extension_cls=self.worker_extension_cls,
         )
 
         max_model_len = model_config.max_model_len
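A hedged sketch of the new knob at the EngineArgs level; the attribute and flag names come from the hunks above, everything else is illustrative.

# Sketch: --worker-extension-cls / EngineArgs.worker_extension_cls is a plain
# string; per the hunk above it is forwarded unchanged into ParallelConfig,
# where WorkerWrapperBase later resolves it by qualified name.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",
    worker_extension_cls="rlhf_utils.WorkerExtension",
)
print(engine_args.worker_extension_cls)  # "rlhf_utils.WorkerExtension"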
vllm/worker/worker_base.py
@@ -558,10 +558,37 @@ class WorkerWrapperBase:
             worker_class = resolve_obj_by_qualname(
                 self.vllm_config.parallel_config.worker_cls)
         else:
+            logger.warning(
+                "passing worker_cls as a class object is strongly deprecated,"
+                " as the serialization of class objects can be tricky and"
+                " error-prone. To be safe, please keep the class in a separate"
+                " module and pass the qualified name of the class as a string."
+            )
             assert isinstance(self.vllm_config.parallel_config.worker_cls,
                               bytes)
             worker_class = cloudpickle.loads(
                 self.vllm_config.parallel_config.worker_cls)
+        if self.vllm_config.parallel_config.worker_extension_cls:
+            worker_extension_cls = resolve_obj_by_qualname(
+                self.vllm_config.parallel_config.worker_extension_cls)
+            extended_calls = []
+            if worker_extension_cls not in worker_class.__bases__:
+                # check any conflicts between worker and worker_extension_cls
+                for attr in dir(worker_extension_cls):
+                    if attr.startswith("__"):
+                        continue
+                    assert not hasattr(worker_class, attr), (
+                        f"Worker class {worker_class} already has an attribute"
+                        f" {attr}, which conflicts with the worker"
+                        f" extension class {worker_extension_cls}.")
+                    if callable(getattr(worker_extension_cls, attr)):
+                        extended_calls.append(attr)
+                # dynamically inherit the worker extension class
+                worker_class.__bases__ = worker_class.__bases__ + (
+                    worker_extension_cls, )
+                logger.info(
+                    "Injected %s into %s for extended collective_rpc calls %s",
+                    worker_extension_cls, worker_class, extended_calls)
         with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during worker initialization
             self.worker = worker_class(**kwargs)
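The injection above is ordinary Python: the extension type is appended to the worker class's __bases__, so its methods become reachable by name (and therefore via collective_rpc) without the extension ever importing the concrete worker class. A self-contained sketch of the same pattern, with stand-in class names:

# Self-contained illustration of the dynamic-inheritance trick used above
# (plain Python, no vLLM imports). The extension never names the concrete
# worker class, yet its methods become ordinary bound methods on it.
class WorkerBase:                  # stands in for vLLM's worker base class
    pass


class Worker(WorkerBase):          # stands in for the V0 or V1 worker
    def execute_model(self):
        return "ran model"


class WorkerExtension:             # user code, oblivious to Worker's identity
    def report_device_id(self):
        return f"hello from {type(self).__name__}"


# dynamically inherit the extension, exactly like WorkerWrapperBase does
if WorkerExtension not in Worker.__bases__:
    Worker.__bases__ = Worker.__bases__ + (WorkerExtension, )

w = Worker()
print(w.execute_model())       # -> ran model
print(w.report_device_id())    # -> hello from Worker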