diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 521faeed..ef05cb99 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -145,8 +145,10 @@ steps:
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
   # TODO: create a dedicated test section for multi-GPU example tests
   # when we have multiple distributed example tests
-  - python3 ../examples/offline_inference/rlhf.py
-  - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py
+  - pushd ../examples/offline_inference
+  - python3 rlhf.py
+  - RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
+  - popd
 
 - label: Metrics, Tracing Test # 10min
   num_gpus: 2
diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py
index 172d18cb..b0418c09 100644
--- a/examples/offline_inference/rlhf.py
+++ b/examples/offline_inference/rlhf.py
@@ -18,72 +18,11 @@ import ray
 import torch
 from ray.util.placement_group import placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+from rlhf_utils import stateless_init_process_group
 from transformers import AutoModelForCausalLM
 
 from vllm import LLM, SamplingParams
 from vllm.utils import get_ip, get_open_port
-from vllm.worker.worker import Worker
-
-
-def stateless_init_process_group(master_address, master_port, rank, world_size,
-                                 device):
-    """
-    vLLM provides `StatelessProcessGroup` to create a process group
-    without considering the global process group in torch.distributed.
-    It is recommended to create `StatelessProcessGroup`, and then initialize
-    the data-plane communication (NCCL) between external (train processes)
-    and vLLM workers.
-    """
-    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
-    from vllm.distributed.utils import StatelessProcessGroup
-    pg = StatelessProcessGroup.create(host=master_address,
-                                      port=master_port,
-                                      rank=rank,
-                                      world_size=world_size)
-    pynccl = PyNcclCommunicator(pg, device=device)
-    return pynccl
-
-
-class MyWorker(Worker):
-    """
-    The `MyWorker` class inherits from `Worker` to provide custom functions.
-    For simplicity, we define the `MyWorker` class in this self-contained
-    script. Normally, we should define the `MyWorker` class in a separate
-    file and pass the qualified name of the class to the `worker_cls`
-    parameter.
-    """
-
-    def init_weight_update_group(self, master_address, master_port,
-                                 rank_offset, world_size):
-        from vllm.distributed.parallel_state import get_world_group
-        rank = get_world_group().rank + rank_offset
-        self.model_update_group = stateless_init_process_group(
-            master_address,
-            master_port,
-            rank,
-            world_size,
-            self.device,
-        )
-
-    def update_weight(self, name, dtype, shape):
-        weight = torch.empty(shape, dtype=dtype, device="cuda")
-        self.model_update_group.broadcast(weight,
-                                          src=0,
-                                          stream=torch.cuda.current_stream())
-
-        self.model_runner.model.load_weights(weights=[(name, weight)])
-
-        del weight
-
-    def check_weights_changed(self):
-        """
-        Check if the weights are updated to 0.
-        """
-        weights_updated = True
-        for name, p in self.model_runner.model.named_parameters():
-            weights_updated = weights_updated and torch.allclose(
-                p, torch.zeros_like(p))
-        return weights_updated
 
 
 class MyLLM(LLM):
@@ -129,7 +68,7 @@ llm = ray.remote(
 )(MyLLM).remote(
     model="facebook/opt-125m",
     enforce_eager=True,
-    worker_cls=MyWorker,
+    worker_extension_cls="rlhf_utils.WorkerExtension",
     tensor_parallel_size=2,
     distributed_executor_backend="ray",
 )
@@ -159,6 +98,7 @@ master_port = get_open_port()
 
 handle = llm.collective_rpc.remote("init_weight_update_group",
                                    args=(master_address, master_port, 1, 3))
+
 model_update_group = stateless_init_process_group(master_address, master_port,
                                                   0, 3, torch.device("cuda:0"))
 ray.get(handle)
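Editor's note: the driver-side update loop that pairs with the `update_weight` RPC is untouched by this patch, so it does not appear in the hunks above. For context, a minimal sketch of that loop (assuming `train_model` is the AutoModelForCausalLM instance the example keeps on cuda:0; all other names come from the script above):

    for name, p in train_model.named_parameters():
        # Ask every vLLM worker to post a matching receive for this tensor.
        handle = llm.collective_rpc.remote("update_weight",
                                           args=(name, p.dtype, p.shape))
        # Rank 0 (this training process) broadcasts; ranks 1..2 (the two
        # tensor-parallel vLLM workers) receive on the same NCCL group.
        model_update_group.broadcast(p,
                                     src=0,
                                     stream=torch.cuda.current_stream())
        ray.get(handle)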
diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py
index 15dc7edc..3ceac0fa 100644
--- a/examples/offline_inference/rlhf_colocate.py
+++ b/examples/offline_inference/rlhf_colocate.py
@@ -17,40 +17,6 @@ from ray.util.placement_group import placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
 from vllm import LLM
-from vllm.worker.worker import Worker
-
-
-class MyWorker(Worker):
-
-    def report_device_id(self) -> str:
-        from vllm.platforms import current_platform
-        self.device_uuid = current_platform.get_device_uuid(self.device.index)
-        return self.device_uuid
-
-    def update_weights_from_ipc_handles(self, ipc_handles):
-        handles = ipc_handles[self.device_uuid]
-        device_id = self.device.index
-        weights = []
-        for name, handle in handles.items():
-            func, args = handle
-            list_args = list(args)
-            # the key is to change device id to the current device id
-            # in case two processes have different CUDA_VISIBLE_DEVICES
-            list_args[6] = device_id
-            tensor = func(*list_args)
-            weights.append((name, tensor))
-        self.model_runner.model.load_weights(weights=weights)
-        torch.cuda.synchronize()
-
-    def check_weights_changed(self):
-        """
-        Check if the weights are updated to 0.
-        """
-        weights_updated = True
-        for name, p in self.model_runner.model.named_parameters():
-            weights_updated = weights_updated and torch.allclose(
-                p, torch.zeros_like(p))
-        return weights_updated
 
 
 class MyLLM(LLM):
@@ -150,7 +116,7 @@ for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]):
     )(MyLLM).remote(
         model="facebook/opt-125m",
         enforce_eager=True,
-        worker_cls=MyWorker,
+        worker_extension_cls="rlhf_utils.ColocateWorkerExtension",
         tensor_parallel_size=2,
         distributed_executor_backend="ray",
         gpu_memory_utilization=0.4,
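Editor's note: `update_weights_from_ipc_handles` (now in rlhf_utils.py, below) is the consumer side of a CUDA IPC exchange; it expects a dict of per-GPU handles keyed by device UUID. A minimal sketch of the producer side, with a hypothetical helper `get_weight_ipc_handles` on a training actor that already knows its device UUID:

    from torch.multiprocessing.reductions import reduce_tensor

    def get_weight_ipc_handles(model, device_uuid):
        # reduce_tensor returns (rebuild_func, args); for CUDA tensors,
        # args[6] holds the producer's device index, which the consumer
        # rewrites to its own index before rebuilding the tensor.
        handles = {name: reduce_tensor(p.detach())
                   for name, p in model.named_parameters()}
        return {device_uuid: handles}

Keying by device UUID rather than device index is what makes this robust when the two processes run with different CUDA_VISIBLE_DEVICES settings.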
diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py
new file mode 100644
index 00000000..11b73b7c
--- /dev/null
+++ b/examples/offline_inference/rlhf_utils.py
@@ -0,0 +1,105 @@
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def stateless_init_process_group(master_address, master_port, rank, world_size,
+                                 device):
+    """
+    vLLM provides `StatelessProcessGroup` to create a process group
+    without considering the global process group in torch.distributed.
+    It is recommended to create `StatelessProcessGroup`, and then initialize
+    the data-plane communication (NCCL) between external (train processes)
+    and vLLM workers.
+    """
+    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+    from vllm.distributed.utils import StatelessProcessGroup
+    pg = StatelessProcessGroup.create(host=master_address,
+                                      port=master_port,
+                                      rank=rank,
+                                      world_size=world_size)
+    pynccl = PyNcclCommunicator(pg, device=device)
+    return pynccl
+
+
+class WorkerExtension:
+    """
+    The class for vLLM's worker to inherit from.
+    By defining an extension class, the code can work no matter what the
+    underlying worker class is. This way, the code stays compatible
+    with both vLLM V0 and V1.
+    NOTE: we define this class in a separate module, and the main module
+    passes its fully qualified name as the `worker_extension_cls` argument.
+    """
+
+    def init_weight_update_group(self, master_address, master_port,
+                                 rank_offset, world_size):
+        from vllm.distributed.parallel_state import get_world_group
+        rank = get_world_group().rank + rank_offset
+        self.model_update_group = stateless_init_process_group(
+            master_address,
+            master_port,
+            rank,
+            world_size,
+            self.device,
+        )
+
+    def update_weight(self, name, dtype, shape):
+        weight = torch.empty(shape, dtype=dtype, device="cuda")
+        self.model_update_group.broadcast(weight,
+                                          src=0,
+                                          stream=torch.cuda.current_stream())
+
+        self.model_runner.model.load_weights(weights=[(name, weight)])
+
+        del weight
+
+    def check_weights_changed(self):
+        """
+        Check if the weights are updated to 0.
+        """
+        weights_updated = True
+        for name, p in self.model_runner.model.named_parameters():
+            weights_updated = weights_updated and torch.allclose(
+                p, torch.zeros_like(p))
+        return weights_updated
+
+
+class ColocateWorkerExtension:
+    """
+    The class for vLLM's worker to inherit from, in the colocate setting.
+    By defining an extension class, the code can work no matter what the
+    underlying worker class is. This way, the code stays compatible
+    with both vLLM V0 and V1.
+    NOTE: we define this class in a separate module, and the main module
+    passes its fully qualified name as the `worker_extension_cls` argument.
+    """
+
+    def report_device_id(self) -> str:
+        from vllm.platforms import current_platform
+        self.device_uuid = current_platform.get_device_uuid(self.device.index)
+        return self.device_uuid
+
+    def update_weights_from_ipc_handles(self, ipc_handles):
+        handles = ipc_handles[self.device_uuid]
+        device_id = self.device.index
+        weights = []
+        for name, handle in handles.items():
+            func, args = handle
+            list_args = list(args)
+            # the key is to change device id to the current device id
+            # in case two processes have different CUDA_VISIBLE_DEVICES
+            list_args[6] = device_id
+            tensor = func(*list_args)
+            weights.append((name, tensor))
+        self.model_runner.model.load_weights(weights=weights)
+        torch.cuda.synchronize()
+
+    def check_weights_changed(self):
+        """
+        Check if the weights are updated to 0.
+        """
+        weights_updated = True
+        for name, p in self.model_runner.model.named_parameters():
+            weights_updated = weights_updated and torch.allclose(
+                p, torch.zeros_like(p))
+        return weights_updated
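Editor's note: once the extension is injected, its methods are reachable through `collective_rpc` like any other worker method, in both the Ray examples above and plain offline use. A minimal sketch outside Ray (assuming two visible GPUs; `LLM.collective_rpc` returns one result per worker):

    from vllm import LLM

    llm = LLM(model="facebook/opt-125m",
              tensor_parallel_size=2,
              worker_extension_cls="rlhf_utils.ColocateWorkerExtension")
    # One return value per worker; here, the UUIDs of both GPUs.
    device_uuids = llm.collective_rpc("report_device_id")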
 world_size: int = field(init=False)
diff --git a/vllm/config.py b/vllm/config.py
index 3f1bff49..9b84d040 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1366,6 +1366,7 @@ class ParallelConfig:
     # will be determined based on the platform.
     worker_cls: str = "auto"
     sd_worker_cls: str = "auto"
+    worker_extension_cls: str = ""
 
     # world_size is TPxPP, it affects the number of workers we create.
     world_size: int = field(init=False)
@@ -1523,6 +1524,9 @@
             raise ValueError("Unable to use nsight profiling unless workers "
                              "run with Ray.")
 
+        assert isinstance(self.worker_extension_cls, str), (
+            "worker_extension_cls must be a string (qualified class name).")
+
 
 @dataclass
 class SchedulerConfig:
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 989eb4db..d033acff 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -202,6 +202,7 @@ class EngineArgs:
     override_pooler_config: Optional[PoolerConfig] = None
     compilation_config: Optional[CompilationConfig] = None
     worker_cls: str = "auto"
+    worker_extension_cls: str = ""
 
     kv_transfer_config: Optional[KVTransferConfig] = None
 
@@ -1015,6 +1016,13 @@
             type=str,
             default="auto",
             help='The worker class to use for distributed execution.')
+        parser.add_argument(
+            '--worker-extension-cls',
+            type=str,
+            default="",
+            help='The worker extension class on top of the worker cls; '
+            'useful if you just want to add new functions to the worker '
+            'class without changing the existing functions.')
         parser.add_argument(
             "--generation-config",
             type=nullable_str,
@@ -1209,6 +1217,7 @@ class EngineArgs:
             ray_workers_use_nsight=self.ray_workers_use_nsight,
             distributed_executor_backend=self.distributed_executor_backend,
             worker_cls=self.worker_cls,
+            worker_extension_cls=self.worker_extension_cls,
         )
 
         max_model_len = model_config.max_model_len
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 7cc1562a..e5662e69 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -558,10 +558,37 @@ class WorkerWrapperBase:
             worker_class = resolve_obj_by_qualname(
                 self.vllm_config.parallel_config.worker_cls)
         else:
+            logger.warning(
+                "passing worker_cls as a class object is strongly deprecated,"
+                " as the serialization of class objects can be tricky and"
+                " error-prone. To be safe, please keep the class in a separate"
+                " module and pass the qualified name of the class as a string."
+            )
             assert isinstance(self.vllm_config.parallel_config.worker_cls,
                               bytes)
             worker_class = cloudpickle.loads(
                 self.vllm_config.parallel_config.worker_cls)
+        if self.vllm_config.parallel_config.worker_extension_cls:
+            worker_extension_cls = resolve_obj_by_qualname(
+                self.vllm_config.parallel_config.worker_extension_cls)
+            extended_calls = []
+            if worker_extension_cls not in worker_class.__bases__:
+                # check any conflicts between worker and worker_extension_cls
+                for attr in dir(worker_extension_cls):
+                    if attr.startswith("__"):
+                        continue
+                    assert not hasattr(worker_class, attr), (
+                        f"Worker class {worker_class} already has an attribute"
+                        f" {attr}, which conflicts with the worker"
+                        f" extension class {worker_extension_cls}.")
+                    if callable(getattr(worker_extension_cls, attr)):
+                        extended_calls.append(attr)
+                # dynamically inherit the worker extension class
+                worker_class.__bases__ = worker_class.__bases__ + (
+                    worker_extension_cls, )
+                logger.info(
+                    "Injected %s into %s for extended collective_rpc calls %s",
+                    worker_extension_cls, worker_class, extended_calls)
         with set_current_vllm_config(self.vllm_config):
             # To make vLLM config available during worker initialization
             self.worker = worker_class(**kwargs)
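Editor's note: the injection in worker_base.py relies on a class's `__bases__` tuple being writable at runtime. A self-contained illustration of the mechanism (illustrative names, mirroring the worker/extension relationship; note the worker keeps a non-object base, as vLLM's workers do):

    class WorkerBase:
        pass

    class Worker(WorkerBase):
        def execute(self):
            return "work"

    class Extension:
        def extra(self):
            return "extended"

    # Append the extension as a base class; its methods now resolve on
    # Worker through the updated MRO, for existing and future instances.
    Worker.__bases__ = Worker.__bases__ + (Extension, )

    w = Worker()
    assert w.execute() == "work"
    assert w.extra() == "extended"

This is also why the conflict check runs before injection: once the extension is appended, any name shared by the worker and the extension would silently resolve in the worker's favor under the MRO instead of failing loudly.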