106 lines
4.0 KiB
Python
106 lines
4.0 KiB
Python
![]() |
# SPDX-License-Identifier: Apache-2.0
|
||
|
import torch
|
||
|
|
||
|
|
||
|
def stateless_init_process_group(master_address, master_port, rank, world_size,
                                 device):
    """
    Build a NCCL communicator on top of a vLLM `StatelessProcessGroup`.

    A `StatelessProcessGroup` is created without considering the global
    process group in torch.distributed. The recommended flow is to create
    the `StatelessProcessGroup` first and then initialize the data-plane
    communication (NCCL) between the external (train) processes and the
    vLLM workers on top of it.

    Args:
        master_address: passed as `host` to `StatelessProcessGroup.create`.
        master_port: passed as `port` to `StatelessProcessGroup.create`.
        rank: rank of the calling process inside the new group.
        world_size: total number of processes in the new group.
        device: device handed to `PyNcclCommunicator`.

    Returns:
        The `PyNcclCommunicator` bound to the newly created group.
    """
    # Imported lazily so that importing this module does not pull in vLLM's
    # distributed machinery.
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    group_kwargs = dict(
        host=master_address,
        port=master_port,
        rank=rank,
        world_size=world_size,
    )
    process_group = StatelessProcessGroup.create(**group_kwargs)
    return PyNcclCommunicator(process_group, device=device)
|
||
|
|
||
|
|
||
|
class WorkerExtension:
    """
    Mixin that vLLM's worker class inherits from.

    Putting the extra methods in an extension class keeps the code
    independent of the concrete worker class underneath, so the same code
    is compatible with both vLLM V0 and V1.

    NOTE: this class lives in its own module; the main module should pass
    its fully qualified name as the `worker_extension_cls` argument.
    """

    def init_weight_update_group(self, master_address, master_port,
                                 rank_offset, world_size):
        """
        Join the weight-update communicator and keep it on the worker.

        The rank used in the new group is this worker's rank in the vLLM
        world group shifted by `rank_offset`. The resulting communicator is
        stored as `self.model_update_group`.
        """
        from vllm.distributed.parallel_state import get_world_group

        shifted_rank = get_world_group().rank + rank_offset
        self.model_update_group = stateless_init_process_group(
            master_address=master_address,
            master_port=master_port,
            rank=shifted_rank,
            world_size=world_size,
            device=self.device,
        )

    def update_weight(self, name, dtype, shape):
        """
        Receive one tensor by broadcast from rank 0 and load it as `name`.

        A buffer of the given `shape` and `dtype` is allocated on "cuda",
        filled through `self.model_update_group.broadcast`, and then passed
        to the model's `load_weights` as a single `(name, tensor)` pair.
        """
        recv_buffer = torch.empty(shape, dtype=dtype, device="cuda")
        update_group = self.model_update_group
        update_group.broadcast(
            recv_buffer,
            src=0,
            stream=torch.cuda.current_stream(),
        )

        self.model_runner.model.load_weights(weights=[(name, recv_buffer)])

        # Drop the local reference to the receive buffer.
        del recv_buffer

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.

        Returns True only when every named parameter of the model is
        close to an all-zero tensor (per `torch.allclose`).
        """
        named_params = self.model_runner.model.named_parameters()
        return all(
            torch.allclose(param, torch.zeros_like(param))
            for _name, param in named_params
        )
|
||
|
|
||
|
|
||
|
class ColocateWorkerExtension:
    """
    Mixin that vLLM's worker class inherits from in the colocate setting.

    Putting the extra methods in an extension class keeps the code
    independent of the concrete worker class underneath, so the same code
    is compatible with both vLLM V0 and V1.

    NOTE: this class lives in its own module; the main module should pass
    its fully qualified name as the `worker_extension_cls` argument.
    """

    def report_device_id(self) -> str:
        """
        Look up, remember and return the UUID of this worker's device.

        The UUID is queried from the current vLLM platform using
        `self.device.index` and cached on the instance as
        `self.device_uuid`.
        """
        from vllm.platforms import current_platform

        uuid = current_platform.get_device_uuid(self.device.index)
        self.device_uuid = uuid
        return uuid

    def update_weights_from_ipc_handles(self, ipc_handles):
        """
        Rebuild tensors from IPC handles and load them into the model.

        `ipc_handles` is indexed by device UUID; the entry for
        `self.device_uuid` maps parameter names to `(func, args)` pairs.
        Each pair is invoked as `func(*args)` after patching the device id
        in `args`, and the resulting `(name, tensor)` list is handed to the
        model's `load_weights`.

        NOTE: this reads `self.device_uuid`, which in this class is only
        assigned by `report_device_id`, so that method has to run first.
        """
        local_handles = ipc_handles[self.device_uuid]
        local_device_index = self.device.index

        def _rebuild(handle):
            rebuild_fn, rebuild_args = handle
            patched_args = list(rebuild_args)
            # The key step: slot 6 carries the device id. Overwrite it with
            # the current device id in case the two processes have
            # different CUDA_VISIBLE_DEVICES.
            patched_args[6] = local_device_index
            return rebuild_fn(*patched_args)

        rebuilt_weights = [
            (param_name, _rebuild(handle))
            for param_name, handle in local_handles.items()
        ]
        self.model_runner.model.load_weights(weights=rebuilt_weights)
        torch.cuda.synchronize()

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.

        Returns True only when every named parameter of the model is
        close to an all-zero tensor (per `torch.allclose`).
        """
        named_params = self.model_runner.model.named_parameters()
        return all(
            torch.allclose(param, torch.zeros_like(param))
            for _name, param in named_params
        )
|